From 3f83250831997e983781d0f2b65d3c5b46ddde60 Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Tue, 15 Aug 2023 08:59:16 +0800 Subject: [PATCH 1/8] Implement `CustomDiffusionAttnProcessor2_0` --- src/diffusers/models/attention_processor.py | 111 ++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 43497c2284ac..ea226836932e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1466,6 +1466,117 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class CustomDiffusionAttnProcessor2_0(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv=True, + train_q_out=True, + hidden_size=None, + cross_attention_dim=None, + out_bias=True, + dropout=0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. 
+ if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states) + value = self.to_v_custom_diffusion(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + # query = attn.head_to_batch_dim(query) + # key = attn.head_to_batch_dim(key) + # value = attn.head_to_batch_dim(value) + inner_dim = hidden_states.shape[-1] + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # attention_probs = attn.get_attention_scores(query, key, attention_mask) + # hidden_states = torch.bmm(attention_probs, value) + # hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + class SlicedAttnProcessor: r""" Processor for implementing sliced attention. From f0114676df5361b058330207cf310942355882ca Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Wed, 16 Aug 2023 09:12:05 +0800 Subject: [PATCH 2/8] Doc-strings and type annotations for `CustomDiffusionAttnProcessor2_0`. 
(#1) * Update attnprocessor.md * Update attention_processor.py --- docs/source/en/api/attnprocessor.md | 5 ++++- src/diffusers/models/attention_processor.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index 7a4812e0961e..0b11c1f5bc5d 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -17,6 +17,9 @@ An attention processor is a class for applying different types of attention mech ## CustomDiffusionAttnProcessor [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor +## CustomDiffusionAttnProcessor2_0 +[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0 + ## AttnAddedKVProcessor [[autodoc]] models.attention_processor.AttnAddedKVProcessor @@ -39,4 +42,4 @@ An attention processor is a class for applying different types of attention mech [[autodoc]] models.attention_processor.SlicedAttnProcessor ## SlicedAttnAddedKVProcessor -[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor \ No newline at end of file +[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ea226836932e..46e9b92ca91c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1468,7 +1468,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class CustomDiffusionAttnProcessor2_0(nn.Module): r""" - Processor for implementing attention for the Custom Diffusion method. + Processor for implementing attention for the Custom Diffusion method + using PyTorch 2.0’s memory-efficient scaled dot-product attention. Args: train_kv (`bool`, defaults to `True`): @@ -1758,6 +1759,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, LoRAAttnAddedKVProcessor, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, + CustomDiffusionAttnProcessor2_0, ] LORA_ATTENTION_PROCESSORS = ( From e849af4e00dbeacada38088f50fe12c336013adf Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Wed, 16 Aug 2023 09:15:52 +0800 Subject: [PATCH 3/8] Interops for `CustomDiffusionAttnProcessor2_0`. 
--- src/diffusers/models/attention_processor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 46e9b92ca91c..419229c59412 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -171,7 +171,7 @@ def set_use_memory_efficient_attention_xformers( LORA_ATTENTION_PROCESSORS, ) is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) + self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0) ) is_added_kv_processor = hasattr(self, "processor") and isinstance( self.processor, @@ -259,7 +259,10 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) elif is_custom_diffusion: - processor = CustomDiffusionAttnProcessor( + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, From d68897932ff28e2eb9c878c9224c773f4a966d3d Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Wed, 16 Aug 2023 22:00:36 +0800 Subject: [PATCH 4/8] Formatted `attention_processor.py`. --- src/diffusers/models/attention_processor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 419229c59412..d3775ad3b61e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -171,7 +171,8 @@ def set_use_memory_efficient_attention_xformers( LORA_ATTENTION_PROCESSORS, ) is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0) + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), ) is_added_kv_processor = hasattr(self, "processor") and isinstance( self.processor, @@ -260,7 +261,9 @@ def set_use_memory_efficient_attention_xformers( processor.to(self.processor.to_q_lora.up.weight.device) elif is_custom_diffusion: attn_processor_class = ( - CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor ) processor = attn_processor_class( train_kv=self.processor.train_kv, From c1678ec8bc1dd860888ebc46f30e2e8710968ae4 Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Thu, 17 Aug 2023 08:48:05 +0800 Subject: [PATCH 5/8] Formatted doc-string in `attention_processor.py` --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d3775ad3b61e..078ee2609d7a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1474,8 +1474,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class CustomDiffusionAttnProcessor2_0(nn.Module): r""" - 
Processor for implementing attention for the Custom Diffusion method - using PyTorch 2.0’s memory-efficient scaled dot-product attention. + Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled + dot-product attention. Args: train_kv (`bool`, defaults to `True`): From f6636adaa3e8915279c1a85304646cc97de2c2d7 Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Thu, 17 Aug 2023 09:09:31 +0800 Subject: [PATCH 6/8] Conditional CustomDiffusion2_0 for training example. --- examples/custom_diffusion/train_custom_diffusion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index a5b4b0846f26..f84a51e688dd 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -50,7 +50,11 @@ UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor +from diffusers.models.attention_processor import ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, +) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -859,7 +863,9 @@ def main(args): unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - attention_class = CustomDiffusionAttnProcessor + attention_class = ( + CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor + ) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers From fc693fc669b41a50f8c4f6e2acc78a3be1ce9b61 Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Thu, 17 Aug 2023 09:12:12 +0800 Subject: [PATCH 7/8] Remove unnecessary reference impl in comments. --- src/diffusers/models/attention_processor.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 078ee2609d7a..19155833dd1d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1547,9 +1547,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a key = detach * key + (1 - detach) * key.detach() value = detach * value + (1 - detach) * value.detach() - # query = attn.head_to_batch_dim(query) - # key = attn.head_to_batch_dim(key) - # value = attn.head_to_batch_dim(value) inner_dim = hidden_states.shape[-1] head_dim = inner_dim // attn.heads @@ -1566,10 +1563,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # attention_probs = attn.get_attention_scores(query, key, attention_mask) - # hidden_states = torch.bmm(attention_probs, value) - # hidden_states = attn.batch_to_head_dim(hidden_states) - if self.train_q_out: # linear proj hidden_states = self.to_out_custom_diffusion[0](hidden_states) From 19edb3faa9c2bb1d0b0a6a15b3acece3b73442ef Mon Sep 17 00:00:00 2001 From: Ruoxi Date: Thu, 17 Aug 2023 13:34:52 +0800 Subject: [PATCH 8/8] Fix `save_attn_procs`. 
--- src/diffusers/loaders.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0907cbfd163a..5f55b8aa7f6e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -518,6 +518,7 @@ def save_attn_procs( """ from .models.attention_processor import ( CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, ) @@ -537,7 +538,10 @@ def save_function(weights, filename): os.makedirs(save_directory, exist_ok=True) is_custom_diffusion = any( - isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + isinstance( + x, + (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), + ) for (_, x) in self.attn_processors.items() ) if is_custom_diffusion: @@ -545,7 +549,14 @@ def save_function(weights, filename): { y: x for (y, x) in self.attn_processors.items() - if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + if isinstance( + x, + ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ), + ) } ) state_dict = model_to_save.state_dict()
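
Usage sketch (not part of the patch series): the snippet below illustrates how the new processor is meant to be wired in, mirroring the conditional class selection that patches 3/8 and 6/8 add — prefer `CustomDiffusionAttnProcessor2_0` whenever `torch.nn.functional.scaled_dot_product_attention` exists, otherwise fall back to `CustomDiffusionAttnProcessor`. The checkpoint id, `train_q_out=False`, and the per-layer size lookup are illustrative assumptions based on the standard Stable Diffusion UNet layout, not taken from this diff.

# Minimal sketch, not part of the patches above. Assumes a standard Stable Diffusion
# UNet; the checkpoint id and train_q_out choice are illustrative.
import torch.nn.functional as F

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Prefer the SDPA-based processor when PyTorch 2.0's
# F.scaled_dot_product_attention is available (same check as in patch 6/8).
attention_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

custom_diffusion_attn_procs = {}
for name in unet.attn_processors.keys():
    # Only cross-attention layers ("...attn2.processor") attend to text features,
    # so only they get trainable key/value projections.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_diffusion_attn_procs[name] = attention_class(
        train_kv=cross_attention_dim is not None,
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(custom_diffusion_attn_procs)

Because both classes expose the same constructor arguments and custom-layer parameter names (`to_k_custom_diffusion`, `to_v_custom_diffusion`, ...), swapping one for the other does not change the saved state dict layout, which is why patch 8/8 only needs to extend the isinstance checks in `save_attn_procs`.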