diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 699395418c51..52970e48147d 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -85,12 +85,21 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
 
         self.lora_scale = lora_scale
 
+    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
+    # when saving the whole text encoder model and when LoRA is unloaded or fused
+    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer.state_dict(
+                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
+            )
+
+        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
     def _fuse_lora(self):
         if self.lora_linear_layer is None:
             return
 
         dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")
 
         w_orig = self.regular_linear_layer.weight.data.float()
         w_up = self.lora_linear_layer.up.weight.data.float()
@@ -112,14 +121,14 @@ def _fuse_lora(self):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")
 
         fused_weight = self.regular_linear_layer.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device
 
-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
@@ -1405,15 +1414,15 @@ def _remove_text_encoder_monkey_patch(self):
     def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
-                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
-                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
-                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+                attn_module.q_proj.lora_linear_layer = None
+                attn_module.k_proj.lora_linear_layer = None
+                attn_module.v_proj.lora_linear_layer = None
+                attn_module.out_proj.lora_linear_layer = None
 
         for _, mlp_module in text_encoder_mlp_modules(text_encoder):
             if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
-                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+                mlp_module.fc1.lora_linear_layer = None
+                mlp_module.fc2.lora_linear_layer = None
 
     @classmethod
     def _modify_text_encoder(
@@ -1447,23 +1456,43 @@ def _modify_text_encoder(
             else:
                 current_rank = rank
 
+            q_linear_layer = (
+                attn_module.q_proj.regular_linear_layer
+                if isinstance(attn_module.q_proj, PatchedLoraProjection)
+                else attn_module.q_proj
+            )
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
+                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
 
+            k_linear_layer = (
+                attn_module.k_proj.regular_linear_layer
+                if isinstance(attn_module.k_proj, PatchedLoraProjection)
+                else attn_module.k_proj
+            )
             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
+                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
 
+            v_linear_layer = (
+                attn_module.v_proj.regular_linear_layer
+                if isinstance(attn_module.v_proj, PatchedLoraProjection)
+                else attn_module.v_proj
+            )
             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
+                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
 
+            out_linear_layer = (
+                attn_module.out_proj.regular_linear_layer
+                if isinstance(attn_module.out_proj, PatchedLoraProjection)
+                else attn_module.out_proj
+            )
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
+                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
@@ -1475,13 +1504,23 @@ def _modify_text_encoder(
                 current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                 current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
 
+                fc1_linear_layer = (
+                    mlp_module.fc1.regular_linear_layer
+                    if isinstance(mlp_module.fc1, PatchedLoraProjection)
+                    else mlp_module.fc1
+                )
                 mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
+                    fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
 
+                fc2_linear_layer = (
+                    mlp_module.fc2.regular_linear_layer
+                    if isinstance(mlp_module.fc2, PatchedLoraProjection)
+                    else mlp_module.fc2
+                )
                 mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
+                    fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index cb0b3b45eb69..671c93a3b2b2 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -168,7 +168,6 @@ def _fuse_lora(self):
             return
 
         dtype, device = self.weight.data.dtype, self.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")
 
         w_orig = self.weight.data.float()
         w_up = self.lora_layer.up.weight.data.float()
@@ -190,14 +189,14 @@ def _fuse_lora(self):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")
 
         fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device
 
-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
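Side note on the `state_dict` override added to `PatchedLoraProjection` above: with LoRA unloaded or fused, the wrapper reports only the wrapped linear layer's keys, so a saved text encoder checkpoint keeps its original key naming. Below is a minimal, self-contained sketch of that pattern; `ToyPatchedProjection` is a hypothetical stand-in for illustration, not the diffusers class.

import torch
import torch.nn as nn


class ToyPatchedProjection(nn.Module):
    """Hypothetical stand-in for PatchedLoraProjection, only to show the state_dict override."""

    def __init__(self, regular_linear_layer):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = None  # set to an nn.Module when LoRA is loaded

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # with LoRA unloaded or fused, expose only the wrapped layer's weights
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

    def forward(self, x):
        out = self.regular_linear_layer(x)
        if self.lora_linear_layer is not None:
            out = out + self.lora_linear_layer(x)
        return out


proj = ToyPatchedProjection(nn.Linear(8, 8))
print(list(proj.state_dict().keys()))  # ['weight', 'bias'] -- no 'regular_linear_layer.' prefix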
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index f17529a1680e..848f2f44adc9 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import tempfile
 import time
@@ -100,6 +101,18 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False):
             torch.zero_(parameter)
 
 
+def state_dicts_almost_equal(sd1, sd2):
+    sd1 = dict(sorted(sd1.items()))
+    sd2 = dict(sorted(sd2.items()))
+
+    models_are_equal = True
+    for ten1, ten2 in zip(sd1.values(), sd2.values()):
+        if (ten1 - ten2).abs().sum() > 1e-3:
+            models_are_equal = False
+
+    return models_are_equal
+
+
 class LoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -674,6 +687,45 @@ def test_load_lora_locally(self):
 
         sd_pipe.unload_lora_weights()
 
+    def test_text_encoder_lora_state_dict_unchanged(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+
+        text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=False,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+        text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        sd_pipe.unload_lora_weights()
+
+        text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys())
+        text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys())
+
+        # default & unloaded LoRA weights should have identical state_dicts
+        assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
+        # default & loaded LoRA weights should NOT have identical state_dicts
+        assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2
+
+        # default & unloaded LoRA weights should have identical state_dicts
+        assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
+        # default & loaded LoRA weights should NOT have identical state_dicts
+        assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2
+
     def test_load_lora_locally_safetensors(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
@@ -1187,3 +1239,22 @@ def test_sdxl_1_0_last_ben(self):
 
         expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
+    def test_sdxl_1_0_fuse_unfuse_all(self):
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
+        text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
+        unet_sd = copy.deepcopy(pipe.unet.state_dict())
+
+        pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
+        pipe.fuse_lora()
+        pipe.unload_lora_weights()
+        pipe.unfuse_lora()
+
+        new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
+        new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
+        new_unet_sd = copy.deepcopy(pipe.unet.state_dict())
+
+        assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
+        assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
+        assert state_dicts_almost_equal(unet_sd, new_unet_sd)
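For context on why `state_dicts_almost_equal` can use a small tolerance: `_unfuse_lora` now recomputes the LoRA delta in float32 and only casts the result back to the weight's dtype, so a fuse/unfuse round trip drifts only by rounding error. A standalone numerical sketch of that idea, using toy shapes and random tensors rather than diffusers code:

import torch

torch.manual_seed(0)
dtype = torch.float16
weight = torch.randn(64, 64, dtype=dtype)  # frozen base weight
w_up = torch.randn(64, 4, dtype=dtype)     # LoRA up projection (rank 4)
w_down = torch.randn(4, 64, dtype=dtype)   # LoRA down projection

# fuse: W' = W + up @ down, computed in float32 and stored back in the original dtype
delta = torch.bmm(w_up.float()[None, :], w_down.float()[None, :])[0]
fused = (weight.float() + delta).to(dtype)

# unfuse the same way: subtract the delta in float32, then cast back
unfused = (fused.float() - delta).to(dtype)

# the round trip differs from the original weight only by float16 rounding
print((unfused - weight).abs().max())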