From 2c6c328f7acae5bf14c0869244d594f989141f0b Mon Sep 17 00:00:00 2001 From: maruo <128989064+maruo555@users.noreply.github.com> Date: Sat, 24 May 2025 02:13:27 +0900 Subject: [PATCH] fix_te_mlp_fc_only --- networks/lora.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 1699a60ff..5f47a012e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -13,6 +13,7 @@ import re from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel +import library.maruo_global_config as maruoCfg setup_logging() import logging @@ -867,6 +868,7 @@ class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE_MLPONLY = ["CLIPAttention", "CLIPMLP"] # 昔のバージョンの状態と同じにしたい場合用(実験用) LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" @@ -1029,7 +1031,12 @@ def create_modules( index = None logger.info(f"create LoRA for Text Encoder:") - text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + if maruoCfg.te_mlp_fc_only: + # 改造ルート + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE_MLPONLY) + else: + # 通常ルート + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")