
NotImplementedError: Only LoRAs with input/output features higher than the current module's input/output features are currently supported #10227

@JakobLS

Description

Describe the bug

Unable to load and fuse a LoRA whose input/output features are lower than the target module's input/output features. See the example and logs below.

Reproduction

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel

# CatVTON transformer whose input projection is wider than the base model's
transformer = FluxTransformer2DModel.from_pretrained(
    "xiaozaa/catvton-flux-alpha",
    torch_dtype=torch.bfloat16
)
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda")

# Load and fuse the Turbo LoRA (trained against the narrower base
# transformer) -- this is where the NotImplementedError is raised
adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
pipe.load_lora_weights(adapter_id)
pipe.fuse_lora()
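
For context, the mismatch can be confirmed on the pipeline above before the LoRA is loaded. This is a quick inspection sketch; given the in_features=64 and out_features=3072 reported in the logs below, the affected module is presumably the transformer's x_embedder:

# Input projection of the CatVTON transformer: 384 input features,
# versus the 64 the Turbo LoRA's lora_A weight was trained against
# (module_in_features=384 vs. in_features=64 in the logs below).
print(pipe.transformer.x_embedder)
# Expected (assumption): Linear(in_features=384, out_features=3072, bias=True)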

Logs

NotImplementedError                       Traceback (most recent call last)
Cell In[7], line 48
     46 # Load and fuse Lora
     47 adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
---> 48 pipe.load_lora_weights(adapter_id)
     49 pipe.fuse_lora()

File ~/miniconda3/envs/env-p11/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:1856, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   1849 transformer_norm_state_dict = {
   1850     k: state_dict.pop(k)
   1851     for k in list(state_dict.keys())
   1852     if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
   1853 }
   1855 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1856 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
   1857     transformer, transformer_lora_state_dict, transformer_norm_state_dict
   1858 )
   1860 if has_param_with_expanded_shape:
   1861     logger.info(
   1862         "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
   1863         "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
   1864         "To get a comprehensive list of parameter names that were modified, enable debug logging."
   1865     )

File ~/miniconda3/envs/env-p11/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:2333, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
   2331 module_out_features, module_in_features = module_weight.shape
   2332 if out_features < module_out_features or in_features < module_in_features:
-> 2333     raise NotImplementedError(
   2334         f"Only LoRAs with input/output features higher than the current module's input/output features "
   2335         f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
   2336         f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
   2337         f"this please open an issue at https://github.com/huggingface/diffusers/issues."
   2338     )
   2340 logger.debug(
   2341     f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
   2342     f"checkpoint contains higher number of features than expected. The number of input_features will be "
   2343     f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
   2344     f"expanded from {module_out_features} to {out_features}."
   2345 )
   2347 has_param_with_shape_update = True

NotImplementedError: Only LoRAs with input/output features higher than the current module's input/output features are currently supported. The provided LoRA contains in_features=64 and out_features=3072, which are lower than module_in_features=384 and module_out_features=3072. If you require support for this please open an issue at https://github.com/huggingface/diffusers/issues.
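
Until the loader supports expanding LoRA shapes in this direction, one possible workaround is to zero-pad the LoRA's lora_A (down-projection) weight for the affected module so its in_features match the expanded transformer, then pass the patched state dict to load_lora_weights, which also accepts a dict. This is an untested sketch; the weight file name and the "x_embedder"/"lora_A" key layout are assumptions and may differ in the actual checkpoint:

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Assumed file name; check the repo for the actual LoRA weight file.
lora_path = hf_hub_download(
    "alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors"
)
state_dict = load_file(lora_path)

target_in = pipe.transformer.x_embedder.in_features  # 384 for catvton-flux-alpha
for key, weight in state_dict.items():
    # lora_A maps in_features -> rank; zero-padding the extra input columns
    # leaves the LoRA a no-op on the added conditioning channels.
    if ("x_embedder" in key and "lora_A" in key
            and weight.ndim == 2 and weight.shape[1] < target_in):
        padded = weight.new_zeros(weight.shape[0], target_in)
        padded[:, : weight.shape[1]] = weight
        state_dict[key] = padded

pipe.load_lora_weights(state_dict)
pipe.fuse_lora()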

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.15.0-1068-aws-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.4.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.5
  • Transformers version: 4.47.0
  • Accelerate version: 1.2.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA L40S, 46068 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul

Labels

bug (Something isn't working)
