Describe the bug
The current implementation of attn_procs isn't really compatible with changing/switching attn_procs.
Why this matters: the lightweight nature of LoRA layers makes them an ideal candidate for rapidly switching between different "models" in a deployment use case, so IMO this would be quite important to add.
Things that currently make this not really possible, all low-hanging fruit:
- LoRA is currently mutually exclusive with xformers and attention slicing. I see there's a LoRAXFormersCrossAttnProcessor implemented, but it doesn't seem to be used anywhere.
- I made some changes to test LoRAXFormersCrossAttnProcessor and it is not working: it produces a shape mismatch somewhere, while the regular LoRA processor works fine with exactly the same inputs (see solution 2 below).
- You can't "disable" LoRA after having enabled it. For some reason, the LoRA attn_procs inherit from nn.Module while the other attn_procs do not, so trying to change the attn_proc afterwards results in an error (a minimal repro is sketched right after this list), e.g.:
  `cannot assign 'diffusers.models.cross_attention.XFormersCrossAttnProcessor' as child module 'processor' (torch.nn.Module or None expected)`
- There should be a way of passing arguments (e.g. the scale) to the attn_procs, related to "How can I change alpha of LoRA?" #2117.
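A minimal, hypothetical repro of that restriction (plain PyTorch, not diffusers code; `FakeCrossAttention` is an invented stand-in):

```python
import torch.nn as nn

class FakeCrossAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for a LoRA attn_proc: because it is an nn.Module, it gets
        # registered as the child module "processor".
        self.processor = nn.Linear(4, 4)

attn = FakeCrossAttention()
# Stand-in for switching to a plain (non-nn.Module) processor such as
# XFormersCrossAttnProcessor -- nn.Module.__setattr__ refuses the assignment:
# TypeError: cannot assign 'object' as child module 'processor'
# (torch.nn.Module or None expected)
attn.processor = object()
```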
I managed to get all this working with some hacks/workarounds. Possible solutions:
1 - I'm not sure about the cleanest solution here. One option would be for the attn_proc loader to be aware of whether xformers is enabled, but then the order in which you enable xformers and load LoRA matters (a rough sketch of that check is below). Maybe have a single attention processor that works differently depending on whether xformers has been enabled or not?
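A rough sketch of the loader-side check, assuming the 0.12 processor classes in diffusers.models.cross_attention and the `unet.attn_processors` property; `pick_lora_proc_cls` is a hypothetical helper, not an existing API:

```python
from diffusers.models.cross_attention import (
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
    XFormersCrossAttnProcessor,
)

def pick_lora_proc_cls(unet):
    # If the UNet is currently running the xformers processors, instantiate
    # the LoRA xformers variant; otherwise use the plain LoRA processor.
    # Caveat from above: this only helps if xformers is enabled *before*
    # the LoRA weights are loaded.
    uses_xformers = any(
        isinstance(proc, XFormersCrossAttnProcessor)
        for proc in unet.attn_processors.values()
    )
    return LoRAXFormersCrossAttnProcessor if uses_xformers else LoRACrossAttnProcessor
```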
2 - There is a bug in the LoRA xformers attn processor: it is missing `hidden_states = attn.batch_to_head_dim(hidden_states)` after this line:
`hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)`
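With that line restored, the end of the xformers LoRA processor's `__call__` would look roughly like this (excerpt only):

```python
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
# The missing line: reshape from (batch * heads, seq_len, head_dim) back to
# (batch, seq_len, heads * head_dim) before the output projection.
hidden_states = attn.batch_to_head_dim(hidden_states)
```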
3 - I fixed this by wrapping another class around the LoRA processor (wrapper below); the loading code has to be adjusted. This allows the attn_procs to be switched from LoRA to non-LoRA and vice versa.
```python
class LoRACrossAttnProcessorWrapper:
    # Plain object on purpose: not an nn.Module, so it is never registered
    # as the child module "processor" and can be swapped out later.
    def __init__(self, attn_proc, scale=1.0):
        self.attn_proc = attn_proc
        self.scale = scale

    def __call__(self, *args, **kwargs):
        # Forward to the wrapped LoRA processor, injecting the configurable scale.
        return self.attn_proc(*args, scale=self.scale, **kwargs)
```
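With the adjusted loading code, switching could then look roughly like this (hypothetical usage: `unet` and `lora_attn_procs` are assumed to be the loaded UNet and the dict of LoRA processors produced by the adjusted loader; `set_attn_processor` and `CrossAttnProcessor` are the existing 0.12 APIs):

```python
from diffusers.models.cross_attention import CrossAttnProcessor

# Enable LoRA: the wrapper is a plain object, so "processor" is never
# registered as a child nn.Module and can be swapped out later.
unet.set_attn_processor(
    {name: LoRACrossAttnProcessorWrapper(proc, scale=0.8) for name, proc in lora_attn_procs.items()}
)

# Disable LoRA again by going back to the default processor.
unet.set_attn_processor(CrossAttnProcessor())
```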
4 - My workaround in 3 also handles this: the scale lives on the wrapper and is passed through on every call, so it can be changed at runtime (example below). It likewise requires the loading code to be adjusted.
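For example (hypothetical, building on the wrapper above and the `unet.attn_processors` property):

```python
# Change the effective LoRA scale/alpha at runtime without reloading weights.
for proc in unet.attn_processors.values():
    if isinstance(proc, LoRACrossAttnProcessorWrapper):
        proc.scale = 0.5
```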
System Info
diffusers 0.12 (master)