
Incompatibility between LoRA attn_procs and other attn_procs, and other LoRA inference issues #2124

@jorgemcgomes

Description

Describe the bug

Why this matters:
The current implementation of the LoRA attn_procs isn't really compatible with changing or switching attention processors.
IMO, supporting this would be quite important: the lightweight nature of LoRA layers makes them an ideal candidate for rapidly switching between different "models" in a deployment scenario.

Things that currently prevent this, all of them low-hanging fruit:

  1. LoRA is currently mutually exclusive with xformers and attention slicing. I see there is a LoRAXFormersCrossAttnProcessor implemented, but it doesn't seem to be used anywhere.
  2. I made some changes to test LoRAXFormersCrossAttnProcessor, and it is not working: it produces a shape mismatch, while the regular LoRA processor works fine with exactly the same inputs (see solution 2 below).
  3. You can't "disable" LoRA after having enabled it. For some reason, the LoRA attn_procs inherit from nn.Module, while the other attn_procs do not, so trying to change the attn_proc results in an error such as: cannot assign 'diffusers.models.cross_attention.XFormersCrossAttnProcessor' as child module 'processor' (torch.nn.Module or None expected). See the reproduction sketch after this list.
  4. There should be a way of passing arguments to the attn_procs, related to "How can I change alpha of LoRA?" #2117.
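
For reference, issue 3 can be reproduced with something like this (a minimal sketch against diffusers 0.12; the model ID and LoRA weights path are just placeholders):

    from diffusers import StableDiffusionPipeline
    from diffusers.models.cross_attention import CrossAttnProcessor

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # Installs a LoRACrossAttnProcessor (an nn.Module) on every attention block.
    pipe.unet.load_attn_procs("path/to/lora_weights")

    # Switching back to a plain processor now fails, because 'processor' was
    # registered as a child nn.Module and the replacement is a plain object:
    #   TypeError: cannot assign 'diffusers.models.cross_attention.CrossAttnProcessor'
    #   as child module 'processor' (torch.nn.Module or None expected)
    pipe.unet.set_attn_processor(CrossAttnProcessor())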

I managed to get all this working with some hacks/workarounds. Possible solutions:

1 - I'm not sure what the cleanest solution is here. One option would be for the attn_procs loader to be aware of whether xformers is enabled, but then the order in which you enable xformers and load the LoRA weights would matter. Maybe have a single attention processor that works differently depending on whether xformers has been enabled or not? A rough sketch of that idea follows.
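
Something along these lines (the name LoRAUnifiedCrossAttnProcessor is made up; it reuses the LoRA layers from LoRACrossAttnProcessor and is written against the 0.12 attention API, so treat it as an illustration of the dispatch rather than a finished implementation):

    import torch
    import xformers.ops

    from diffusers.models.cross_attention import LoRACrossAttnProcessor

    class LoRAUnifiedCrossAttnProcessor(LoRACrossAttnProcessor):
        def __init__(self, hidden_size, cross_attention_dim=None, rank=4, use_xformers=False):
            super().__init__(hidden_size, cross_attention_dim, rank)
            # Could instead be flipped by enable_xformers_memory_efficient_attention()
            # rather than being fixed at construction time.
            self.use_xformers = use_xformers

        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
            _, sequence_length, _ = hidden_states.shape
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

            # Regular projections plus the scaled LoRA deltas.
            query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

            query = attn.head_to_batch_dim(query)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)

            if self.use_xformers:
                # Memory-efficient attention path.
                hidden_states = xformers.ops.memory_efficient_attention(
                    query.contiguous(), key.contiguous(), value.contiguous(), attn_bias=attention_mask
                )
                hidden_states = hidden_states.to(query.dtype)
            else:
                # Vanilla attention path.
                attention_probs = attn.get_attention_scores(query, key, attention_mask)
                hidden_states = torch.bmm(attention_probs, value)

            # Fold the heads back before the output projection (see solution 2 below).
            hidden_states = attn.batch_to_head_dim(hidden_states)

            # Output projection plus its LoRA delta, then dropout.
            hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            return hidden_states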

2 - There is a bug in the LoRA xformers attention processor: it is missing hidden_states = attn.batch_to_head_dim(hidden_states) after this line:

hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
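
In context, the end of the processor's __call__ would look roughly like this with the fix applied (a paraphrase of the 0.12 code, not an exact quote):

    hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
    # Fix: fold the heads back into the hidden dimension before the output
    # projection, as the non-xformers LoRA processor already does.
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # Output projection plus its LoRA delta, then dropout.
    hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states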

3 - I fixed this by wrapping another class around the LoRA processor (the loading code has to be adjusted accordingly). Because the wrapper is a plain Python object rather than an nn.Module, the attn_procs can then be switched from LoRA to non-LoRA and vice versa.

    class LoRACrossAttnProcessorWrapper:
        # Plain Python wrapper (deliberately NOT an nn.Module) around a
        # LoRACrossAttnProcessor, so the processor can later be swapped for a
        # non-LoRA one without hitting the child-module assignment error.
        def __init__(self, attn_proc, scale=1.0):
            self.attn_proc = attn_proc
            self.scale = scale

        def __call__(self, *args, **kwargs):
            # Forward to the wrapped LoRA processor, injecting the current scale.
            return self.attn_proc(*args, scale=self.scale, **kwargs)
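
This is roughly how the adjusted loading looks on my side (a sketch: build_lora_attn_procs is a hypothetical helper that builds the per-layer LoRACrossAttnProcessor instances the same way load_attn_procs does; the key point is that the raw nn.Module processors are never assigned directly, only the wrappers):

    from diffusers.models.cross_attention import CrossAttnProcessor

    lora_procs = build_lora_attn_procs("path/to/lora_weights", pipe.unet)
    pipe.unet.set_attn_processor(
        {name: LoRACrossAttnProcessorWrapper(proc, scale=0.8) for name, proc in lora_procs.items()}
    )

    # Switching back to the plain processors now works, since nothing was
    # registered as a child nn.Module under 'processor'.
    pipe.unet.set_attn_processor(CrossAttnProcessor())

    # Note: because the wrapped LoRA weights are no longer registered with the
    # UNet, they have to be moved to the right device/dtype manually.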

4 - My workaround for 3 also handles this: the wrapper exposes a scale attribute that can be changed at runtime. Again, the loading code has to be adjusted.
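
Continuing the sketch above, the scale can then be changed between generations without reloading anything:

    # Adjust the effective LoRA strength at runtime by mutating the wrappers
    # (setting it to 0.0 effectively disables the LoRA contribution).
    for proc in pipe.unet.attn_processors.values():
        if isinstance(proc, LoRACrossAttnProcessorWrapper):
            proc.scale = 0.5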

System Info

diffusers 0.12 (master)


Labels

bug (Something isn't working), stale (Issues that haven't received updates)
