From fb708fba19276c08fdce5849447d725293cbf962 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 04:13:34 +0900 Subject: [PATCH 01/11] fix monkey-patch for text_encoder --- src/diffusers/loaders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e50bc31a5c63..f9840bca626e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,14 +943,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. - lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - old_forward = module.forward - def new_forward(x): - return old_forward(x) + lora_layer(x) + if name in attn_processors: + module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + module.old_forward = module.forward - # Monkey-patch. - module.forward = new_forward + def new_forward(self, x): + return self.old_forward(x) + self.lora_layer(x) + + # Monkey-patch. + module.forward = new_forward.__get__(module) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: From 6e8f3ab897a6c068b5ac997887cd79dbef6618d0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 00:29:14 +0900 Subject: [PATCH 02/11] add test_text_encoder_lora_monkey_patch() --- tests/models/test_lora_layers.py | 62 ++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..ffdc7569d2e1 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -212,3 +212,65 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def test_text_encoder_lora_monkey_patch(self): + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 768) + + text_lora_attn_procs = {} + for name, module in pipe.text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_features, cross_attention_dim=None + ).to("cuda") + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # make sure that the lora_up.weights are zeroed out + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_up_weight = lora_linear_layer.up.weight + assert torch.allclose( + lora_up_weight, torch.zeros_like(lora_up_weight) + ), "lora_up_weight should be zeroed out" + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # make lora_up.weights as random + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" From 851175565342669deeb59aa95e446f31b4b9b256 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 01:20:00 +0900 Subject: [PATCH 03/11] verify that it's okay to release the attn_procs --- tests/models/test_lora_layers.py | 42 +++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ffdc7569d2e1..24043544a74d 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import os import tempfile import unittest @@ -22,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -232,25 +233,27 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 768) + # create lora_attn_procs with zeroed out up.weights text_lora_attn_procs = {} for name, module in pipe.text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None - ).to("cuda") + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") + + # make sure that the up.weights are zeroed out + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + assert torch.allclose( + layer_module.up.weight, torch.zeros_like(layer_module.up.weight) + ), "lora_up_weight should be zeroed out" + + text_lora_attn_procs[name] = attn_proc # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) - # make sure that the lora_up.weights are zeroed out - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_up_weight = lora_linear_layer.up.weight - assert torch.allclose( - lora_up_weight, torch.zeros_like(lora_up_weight) - ), "lora_up_weight should be zeroed out" + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. 
+ del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -260,12 +263,13 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # make lora_up.weights as random - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + # set randn to lora_up.weights + for name, _ in pipe.text_encoder.named_modules(): + if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): + module = pipe.text_encoder.get_submodule(name) + assert hasattr(module, "lora_layer"), "lora_layer should be added" + assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" + module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 81915f48dfff3cd2e2654bc820088572f4e8f5db Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:47:58 +0900 Subject: [PATCH 04/11] fix closure version --- src/diffusers/loaders.py | 15 ++++----- tests/models/test_lora_layers.py | 53 ++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f9840bca626e..ad1096f65c21 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,16 +943,17 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward - if name in attn_processors: - module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - module.old_forward = module.forward + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) - def new_forward(self, x): - return self.old_forward(x) + self.lora_layer(x) + return new_forward - # Monkey-patch. - module.forward = new_forward.__get__(module) + # Monkey-patch. 
+ module.forward = make_new_forward(old_forward, lora_layer) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 24043544a74d..6cf79a0c11cb 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -218,14 +218,31 @@ def test_lora_save_load_legacy(self): def get_dummy_tokens(self): max_seq_length = 77 - inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) prepared_inputs = {} prepared_inputs["input_ids"] = inputs return prepared_inputs + def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + text_lora_attn_procs[name] = attn_proc + return text_lora_attn_procs + def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") dummy_tokens = self.get_dummy_tokens() @@ -234,19 +251,7 @@ def test_text_encoder_lora_monkey_patch(self): assert outputs_without_lora.shape == (1, 77, 768) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = {} - for name, module in pipe.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") - - # make sure that the up.weights are zeroed out - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - assert torch.allclose( - layer_module.up.weight, torch.zeros_like(layer_module.up.weight) - ), "lora_up_weight should be zeroed out" - - text_lora_attn_procs[name] = attn_proc + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) @@ -263,13 +268,15 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # set randn to lora_up.weights - for name, _ in pipe.text_encoder.named_modules(): - if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): - module = pipe.text_encoder.get_submodule(name) - assert hasattr(module, "lora_layer"), "lora_layer should be added" - assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" - 
module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) + # create lora_attn_procs with randn up.weights + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. + del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 88db546c01eff271025ab1581f467f41be337c3f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:53:05 +0900 Subject: [PATCH 05/11] add comment --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ad1096f65c21..7eb389184ed9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -946,6 +946,7 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) old_forward = module.forward + # create a new scope that locks in the old_forward, lora_layer value for each new_forward function def make_new_forward(old_forward, lora_layer): def new_forward(x): return old_forward(x) + lora_layer(x) From 1da772b9fe8f64702989c1319348dfffd65dc491 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 00:02:01 +0900 Subject: [PATCH 06/11] Fix to reuse utility functions --- tests/models/test_lora_layers.py | 64 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6cf79a0c11cb..528c6e8bc35a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -44,15 +44,33 @@ def create_unet_lora_layers(unet: nn.Module): return lora_attn_procs, unet_lora_layers -def create_text_encoder_lora_layers(text_encoder: nn.Module): +def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + return text_lora_attn_procs + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers +def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): + for _, attn_proc in text_lora_attn_procs.items(): + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -224,63 +242,49 @@ def get_dummy_tokens(self): prepared_inputs["input_ids"] = inputs return prepared_inputs - def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) - # set up.weights - for 
layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - weight = ( - torch.randn_like(layer_module.up.weight) - if randn_weight - else torch.zeros_like(layer_module.up.weight) - ) - layer_module.up.weight = torch.nn.Parameter(weight) - text_lora_attn_procs[name] = attn_proc - return text_lora_attn_procs - def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) dummy_tokens = self.get_dummy_tokens() # inference without lora outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 768) + assert outputs_without_lora.shape == (1, 77, 32) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=False) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert torch.allclose( outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" # create lora_attn_procs with randn up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. 
+ del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert not torch.allclose( outputs_without_lora, outputs_with_lora From 8a26848d62cc43b71706d1f7028de5771d35d760 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 00:47:17 +0900 Subject: [PATCH 07/11] make LoRAAttnProcessor targets to self_attn --- src/diffusers/loaders.py | 4 +++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 1 + tests/models/test_lora_layers.py | 8 +++++--- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7eb389184ed9..5e9e96cbde0d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,6 +33,7 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_TARGET_MODULES, + TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, @@ -943,7 +944,8 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. - lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + attn_processor_name = ".".join(name.split(".")[:-1]) + lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name)) old_forward = module.forward # create a new scope that locks in the old_forward, lora_layer value for each new_forward function diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cd3a1b8f3dd4..772c36b1177b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_ATTN_MODULE, TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 1134ba6fb656..93d5c8cc42cd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -31,3 +31,4 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] +TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 528c6e8bc35a..1c7e07744cd2 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.models.attention_processor import LoRAAttnProcessor -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device def create_unet_lora_layers(unet: nn.Module): @@ -47,8 +47,10 @@ def create_unet_lora_layers(unet: nn.Module): def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + 
text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_proj.out_features, cross_attention_dim=None + ) return text_lora_attn_procs From 28c69eefe7f80aced892a1b715f3f913d69c124f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 01:13:10 +0900 Subject: [PATCH 08/11] fix LoRAAttnProcessor target --- examples/dreambooth/train_dreambooth_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e640542e36da..ceb360138f13 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -58,7 +58,7 @@ SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -839,9 +839,9 @@ def main(args): if args.train_text_encoder: text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None + hidden_size=module.out_proj.out_features, cross_attention_dim=None ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) temp_pipeline = StableDiffusionPipeline.from_pretrained( From 3a74c7e6d6496351a40cd47c028abde31244991b Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 01:40:47 +0900 Subject: [PATCH 09/11] make style --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5e9e96cbde0d..64a0e942fc77 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,7 +33,6 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_TARGET_MODULES, - TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, From 160a4d356f2b08df4171d15641b1e38389178496 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 00:47:59 +0900 Subject: [PATCH 10/11] fix split key --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 64a0e942fc77..6255ff89d5c9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -70,8 +70,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]): self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder - self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"] + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` From f14329d26351a3361afb16b42f4a0e34ab3da3c3 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 00:58:05 +0900 Subject: [PATCH 11/11] Update src/diffusers/loaders.py --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6255ff89d5c9..3a3db83f62da 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -948,6 +948,7 
@@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): old_forward = module.forward # create a new scope that locks in the old_forward, lora_layer value for each new_forward function + # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060 def make_new_forward(old_forward, lora_layer): def new_forward(x): return old_forward(x) + lora_layer(x)
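
---

Note (not part of the patches): the series above converges on a closure-based monkey patch of the text encoder's `q/k/v/out_proj` forwards, replacing the `__get__`-bound-method version from patch 01. Below is a minimal, self-contained sketch of that technique under stated assumptions: `TinyLoRALinear` and `patch_linear` are illustrative stand-ins invented here, not diffusers APIs; the real `_modify_text_encoder` fetches the LoRA layer from a `LoRAAttnProcessor` keyed by the parent `.self_attn` module name.

```python
# Minimal sketch of the closure-based monkey patch used in patches 04/05/11.
# TinyLoRALinear is a hypothetical stand-in for diffusers' LoRALinearLayer.
import gc

import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # zero init => LoRA is a no-op until trained

    def forward(self, x):
        return self.up(self.down(x))


def patch_linear(module: nn.Linear, lora_layer: TinyLoRALinear):
    old_forward = module.forward

    # The factory creates a new scope so each patched module captures its *own*
    # old_forward/lora_layer (instead of the last loop iteration's values), and
    # the closure keeps lora_layer alive after the caller drops its reference.
    def make_new_forward(old_forward, lora_layer):
        def new_forward(x):
            return old_forward(x) + lora_layer(x)

        return new_forward

    module.forward = make_new_forward(old_forward, lora_layer)


proj = nn.Linear(32, 32)
x = torch.randn(1, 77, 32)
baseline = proj(x)

lora = TinyLoRALinear(32, 32)
patch_linear(proj, lora)
del lora  # mirrors `del text_attn_procs; gc.collect()` in the tests
gc.collect()

# up.weight is zero, so the patched output matches the un-patched one.
assert torch.allclose(proj(x), baseline)
```

The design point the tests exercise: with the closure version, the `attn_processors` dict passed to `_modify_text_encoder` can be released afterwards, because each `new_forward` retains its own reference to the LoRA layer; the patch 01 approach instead attached `lora_layer`/`old_forward` as module attributes and rebound `new_forward` via `__get__`.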