diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 13f5ef4570a7..e2dd3322fdcb 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7802e307c028..d73a41b35e7c 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3909,6 +3909,11 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        # Convert from the original (non-diffusers) Lumina2 LoRA layout if needed.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
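
For reference, a minimal sketch of how this conversion path would be exercised end to end. The hub repo id and weight file name below are hypothetical placeholders, and `Lumina2Pipeline` is assumed to be the pipeline class that mixes in `Lumina2LoraLoaderMixin`:

```python
import torch
from diffusers import Lumina2Pipeline

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)

# A checkpoint whose keys start with "diffusion_model." (the original trainer
# layout) is detected inside `lora_state_dict` and routed through
# `_convert_non_diffusers_lumina2_lora_to_diffusers` before the usual loading.
pipe.load_lora_weights("some-user/lumina2-lora", weight_name="pytorch_lora_weights.safetensors")  # hypothetical
```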