diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py
index 7e22f0559619..0462772e3140 100644
--- a/scripts/convert_kandinsky_to_diffusers.py
+++ b/scripts/convert_kandinsky_to_diffusers.py
@@ -1,11 +1,13 @@
 import argparse
+import os
 import tempfile

 import torch
 from accelerate import load_checkpoint_and_dispatch

 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.models.vq_model import VQModel
 from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel
@@ -901,15 +903,413 @@ def inpaint_text2img(*, args, checkpoint_map_location):
     return inpaint_unet_model, text_proj_model


+# movq
+
+MOVQ_CONFIG = {
+    "in_channels": 3,
+    "out_channels": 3,
+    "latent_channels": 4,
+    "down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"),
+    "up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
+    "num_vq_embeddings": 16384,
+    "block_out_channels": (128, 256, 256, 512),
+    "vq_embed_dim": 4,
+    "layers_per_block": 2,
+    "norm_type": "spatial",
+}
+
+
+def movq_model_from_original_config():
+    movq = VQModel(**MOVQ_CONFIG)
+    return movq
+
+
+def movq_encoder_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # conv_in
+    diffusers_checkpoint.update(
+        {
+            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
+            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
+        }
+    )
+
+    # down_blocks
+    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
+        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
+        down_block_prefix = f"encoder.down.{down_block_idx}"
+
+        # resnets
+        for resnet_idx, resnet in enumerate(down_block.resnets):
+            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
+            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
+
+            diffusers_checkpoint.update(
+                movq_resnet_to_diffusers_checkpoint(
+                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+                )
+            )
+
+        # downsample
+
+        # There is no downsample on the last down block, so skip it there.
+        if down_block_idx != len(model.encoder.down_blocks) - 1:
+            # There is a single downsample in the original checkpoint but a list of
+            # downsamplers in the diffusers model.
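+            # For example (original key names per the MoVQ checkpoint layout, new names per VQModel):
+            #   "encoder.down.0.downsample.conv.weight" -> "encoder.down_blocks.0.downsamplers.0.conv.weight"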
+            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
+            downsample_prefix = f"{down_block_prefix}.downsample.conv"
+            diffusers_checkpoint.update(
+                {
+                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
+                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
+                }
+            )
+
+        # attentions
+
+        if hasattr(down_block, "attentions"):
+            for attention_idx, _ in enumerate(down_block.attentions):
+                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
+                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
+                diffusers_checkpoint.update(
+                    movq_attention_to_diffusers_checkpoint(
+                        checkpoint,
+                        diffusers_attention_prefix=diffusers_attention_prefix,
+                        attention_prefix=attention_prefix,
+                    )
+                )
+
+    # mid block
+
+    # mid block attentions
+
+    # There is a single hardcoded attention block in the middle of the MoVQ encoder
+    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
+    attention_prefix = "encoder.mid.attn_1"
+    diffusers_checkpoint.update(
+        movq_attention_to_diffusers_checkpoint(
+            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
+        )
+    )
+
+    # mid block resnets
+
+    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
+        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
+
+        # the hardcoded suffixes to `block_` are 1 and 2
+        orig_resnet_idx = diffusers_resnet_idx + 1
+        # There are two hardcoded resnets in the middle of the MoVQ encoder
+        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
+
+        diffusers_checkpoint.update(
+            movq_resnet_to_diffusers_checkpoint(
+                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+            )
+        )
+
+    diffusers_checkpoint.update(
+        {
+            # conv_norm_out
+            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
+            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
+            # conv_out
+            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
+            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def movq_decoder_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # conv_in
+    diffusers_checkpoint.update(
+        {
+            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
+            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
+        }
+    )
+
+    # up_blocks
+
+    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
+        # up_blocks are stored in reverse order in the MoVQ checkpoint
+        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
+
+        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
+        up_block_prefix = f"decoder.up.{orig_up_block_idx}"
+
+        # resnets
+        for resnet_idx, resnet in enumerate(up_block.resnets):
+            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
+            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
+
+            diffusers_checkpoint.update(
+                movq_resnet_to_diffusers_checkpoint_spatial_norm(
+                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+                )
+            )
+
+        # upsample
+
+        # There is no upsample on the last up block, so skip it there.
+        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
+            # There is a single upsample in the original checkpoint but a list of
+            # upsamplers in the diffusers model.
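+            # For example, with four up blocks the indices reverse:
+            #   "decoder.up.3.upsample.conv.weight" -> "decoder.up_blocks.0.upsamplers.0.conv.weight"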
+            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
+            upsample_prefix = f"{up_block_prefix}.upsample.conv"
+            diffusers_checkpoint.update(
+                {
+                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
+                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
+                }
+            )
+
+        # attentions
+
+        if hasattr(up_block, "attentions"):
+            for attention_idx, _ in enumerate(up_block.attentions):
+                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
+                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
+                diffusers_checkpoint.update(
+                    movq_attention_to_diffusers_checkpoint_spatial_norm(
+                        checkpoint,
+                        diffusers_attention_prefix=diffusers_attention_prefix,
+                        attention_prefix=attention_prefix,
+                    )
+                )
+
+    # mid block
+
+    # mid block attentions
+
+    # There is a single hardcoded attention block in the middle of the MoVQ decoder
+    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
+    attention_prefix = "decoder.mid.attn_1"
+    diffusers_checkpoint.update(
+        movq_attention_to_diffusers_checkpoint_spatial_norm(
+            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
+        )
+    )
+
+    # mid block resnets
+
+    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
+        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
+
+        # the hardcoded suffixes to `block_` are 1 and 2
+        orig_resnet_idx = diffusers_resnet_idx + 1
+        # There are two hardcoded resnets in the middle of the MoVQ decoder
+        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
+
+        diffusers_checkpoint.update(
+            movq_resnet_to_diffusers_checkpoint_spatial_norm(
+                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+            )
+        )
+
+    diffusers_checkpoint.update(
+        {
+            # conv_norm_out
+            "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"],
+            "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"],
+            "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"],
+            "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"],
+            "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"],
+            "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"],
+            # conv_out
+            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
+            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+    rv = {
+        # norm1
+        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
+        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
+        # conv1
+        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
+        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
+        # norm2
+        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
+        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
+        # conv2
+        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
+        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
+    }
+
+    if resnet.conv_shortcut is not None:
+        rv.update(
+            {
+                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
+                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
+            }
+        )
+
+    return rv
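+
+
+# The decoder's resnets and attentions use SpatialNorm, so each `norm` maps to
+# `norm_layer`, `conv_y`, and `conv_b` parameters instead of a single GroupNorm.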
+def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+    rv = {
+        # norm1
+        f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"],
+        f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"],
+        f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"],
+        f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"],
+        f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"],
+        f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"],
+        # conv1
+        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
+        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
+        # norm2
+        f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"],
+        f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"],
+        f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"],
+        f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"],
+        f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"],
+        f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"],
+        # conv2
+        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
+        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
+    }
+
+    if resnet.conv_shortcut is not None:
+        rv.update(
+            {
+                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
+                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
+            }
+        )
+
+    return rv
+
+
+def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+    return {
+        # norm
+        f"{diffusers_attention_prefix}.norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
+        f"{diffusers_attention_prefix}.norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
+        # query
+        f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
+        # key
+        f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
+        # value
+        f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
+        # proj_attn
+        f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
+            :, :, 0, 0
+        ],
+        f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
+    }
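+
+
+# In both attention converters, the original q/k/v/proj_out weights are 1x1 convolutions
+# with shape (out_channels, in_channels, 1, 1); slicing with [:, :, 0, 0] converts them
+# into the (out_channels, in_channels) linear weights that diffusers expects.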
+def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+    return {
+        # norm
+        f"{diffusers_attention_prefix}.norm.norm_layer.weight": checkpoint[f"{attention_prefix}.norm.norm_layer.weight"],
+        f"{diffusers_attention_prefix}.norm.norm_layer.bias": checkpoint[f"{attention_prefix}.norm.norm_layer.bias"],
+        f"{diffusers_attention_prefix}.norm.conv_y.weight": checkpoint[f"{attention_prefix}.norm.conv_y.weight"],
+        f"{diffusers_attention_prefix}.norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"],
+        f"{diffusers_attention_prefix}.norm.conv_b.weight": checkpoint[f"{attention_prefix}.norm.conv_b.weight"],
+        f"{diffusers_attention_prefix}.norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"],
+        # query
+        f"{diffusers_attention_prefix}.attention.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.attention.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"],
+        # key
+        f"{diffusers_attention_prefix}.attention.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.attention.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"],
+        # value
+        f"{diffusers_attention_prefix}.attention.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
+        f"{diffusers_attention_prefix}.attention.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"],
+        # proj_attn
+        f"{diffusers_attention_prefix}.attention.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
+            :, :, 0, 0
+        ],
+        f"{diffusers_attention_prefix}.attention.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
+    }
+
+
+def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # encoder
+    diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint))
+
+    # quant_conv
+    diffusers_checkpoint.update(
+        {
+            "quant_conv.weight": checkpoint["quant_conv.weight"],
+            "quant_conv.bias": checkpoint["quant_conv.bias"],
+        }
+    )
+
+    # quantize
+    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]})
+
+    # post_quant_conv
+    diffusers_checkpoint.update(
+        {
+            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
+            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
+        }
+    )
+
+    # decoder
+    diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint))
+
+    return diffusers_checkpoint
+
+
+def movq(*, args, checkpoint_map_location):
+    print("loading movq")
+
+    movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location)
+
+    movq_model = movq_model_from_original_config()
+
+    movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint(movq_model, movq_checkpoint)
+
+    del movq_checkpoint
+
+    load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True)
+
+    print("done loading movq")
+
+    return movq_model
+

 def load_checkpoint_to_model(checkpoint, model, strict=False):
-    with tempfile.NamedTemporaryFile() as file:
+    with tempfile.NamedTemporaryFile(delete=False) as file:
         torch.save(checkpoint, file.name)
         del checkpoint
         if strict:
             model.load_state_dict(torch.load(file.name), strict=True)
         else:
             load_checkpoint_and_dispatch(model, file.name, device_map="auto")
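+    # The file is created with delete=False so it can be re-opened by name while the
+    # handle is still open (re-opening an open NamedTemporaryFile fails on Windows);
+    # remove it explicitly once the weights have been loaded.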
+    os.remove(file.name)


 if __name__ == "__main__":
@@ -921,24 +1321,31 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
         "--prior_checkpoint_path",
         default=None,
         type=str,
-        required=True,
+        required=False,
         help="Path to the prior checkpoint to convert.",
     )
     parser.add_argument(
-        "--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert."
+        "--clip_stat_path", default=None, type=str, required=False, help="Path to the clip stats checkpoint to convert."
     )
     parser.add_argument(
         "--text2img_checkpoint_path",
         default=None,
         type=str,
-        required=True,
+        required=False,
         help="Path to the text2img checkpoint to convert.",
     )
+    parser.add_argument(
+        "--movq_checkpoint_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Path to the movq checkpoint to convert.",
+    )
     parser.add_argument(
         "--inpaint_text2img_checkpoint_path",
         default=None,
         type=str,
-        required=True,
+        required=False,
         help="Path to the inpaint text2img checkpoint to convert.",
     )
     parser.add_argument(
@@ -979,5 +1386,8 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
         inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
         inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
+    elif args.debug == "decoder":
+        decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
+        decoder.save_pretrained(f"{args.dump_path}/decoder")
     else:
         raise ValueError(f"unknown debug value : {args.debug}")
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 0b313b83d360..2b31fa9e2f38 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -369,3 +369,24 @@ def forward(self, x, emb):
         x = F.group_norm(x, self.num_groups, eps=self.eps)
         x = x * (1 + scale) + shift
         return x
+
+
+class SpatialNorm(nn.Module):
+    """
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
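+
+    The feature map `f` is group-normalized, then scaled and shifted per pixel by two 1x1
+    convolutions of the quantized latents `zq` (nearest-upsampled to `f`'s spatial size):
+    `out = GroupNorm(f) * conv_y(zq) + conv_b(zq)`.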
+    """
+
+    def __init__(
+        self,
+        f_channels,
+        zq_channels,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f, zq):
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index d9d539959c09..83bec9a52593 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .attention import AdaGroupNorm
+from .attention import AdaGroupNorm, SpatialNorm


 class Upsample1D(nn.Module):
@@ -460,7 +460,7 @@ def __init__(
         eps=1e-6,
         non_linearity="swish",
         skip_time_act=False,
-        time_embedding_norm="default",  # default, scale_shift, ada_group
+        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
         kernel=None,
         output_scale_factor=1.0,
         use_in_shortcut=None,
@@ -487,6 +487,8 @@ def __init__(

         if self.time_embedding_norm == "ada_group":
             self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm1 = SpatialNorm(in_channels, temb_channels)
         else:
             self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

@@ -497,7 +499,7 @@ def __init__(
                 self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
                 self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-            elif self.time_embedding_norm == "ada_group":
+            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                 self.time_emb_proj = None
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
@@ -506,6 +508,8 @@ def __init__(

         if self.time_embedding_norm == "ada_group":
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm2 = SpatialNorm(out_channels, temb_channels)
         else:
             self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

@@ -551,7 +555,7 @@ def __init__(
     def forward(self, input_tensor, temb):
         hidden_states = input_tensor

-        if self.time_embedding_norm == "ada_group":
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
             hidden_states = self.norm1(hidden_states, temb)
         else:
             hidden_states = self.norm1(hidden_states)
@@ -579,7 +583,7 @@ def forward(self, input_tensor, temb):
         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb

-        if self.time_embedding_norm == "ada_group":
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
             hidden_states = self.norm2(hidden_states, temb)
         else:
             hidden_states = self.norm2(hidden_states)
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 0004f074c563..f91132a61397 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -18,7 +18,7 @@
 import torch.nn.functional as F
 from torch import nn

-from .attention import AdaGroupNorm
+from .attention import AdaGroupNorm, AttentionBlock, SpatialNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
@@ -348,6 +348,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
         )
     elif up_block_type == "AttnUpDecoderBlock2D":
         return AttnUpDecoderBlock2D(
@@ -360,6 +361,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
         )
     elif up_block_type == "KUpBlock2D":
         return KUpBlock2D(
@@ -406,7 +408,6 @@ def __init__(
         super().__init__()
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         self.add_attention = add_attention
-
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -439,7 +440,6 @@ def __init__(
                         upcast_softmax=True,
                         _from_deprecated_attn_block=True,
                     )
-                )
             else:
                 attentions.append(None)
@@ -465,7 +465,8 @@ def forward(self, hidden_states, temb=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
-                hidden_states = attn(hidden_states)
+                hidden_states = attn(hidden_states, temb)
+
             hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -1971,6 +1972,30 @@ def custom_forward(*inputs):

         return hidden_states


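+# MOVQAttention applies SpatialNorm (conditioned on the quantized latents, which arrive
+# via the `temb` argument) followed by self-attention over the flattened spatial
+# positions, with a residual connection.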
+class MOVQAttention(nn.Module):
+    def __init__(self, query_dim, temb_channels, attn_num_head_channels):
+        super().__init__()
+
+        self.norm = SpatialNorm(query_dim, temb_channels)
+        num_heads = query_dim // attn_num_head_channels if attn_num_head_channels is not None else 1
+        dim_head = attn_num_head_channels if attn_num_head_channels is not None else query_dim
+        self.attention = Attention(
+            query_dim=query_dim,
+            heads=num_heads,
+            dim_head=dim_head,
+            bias=True,
+        )
+
+    def forward(self, hidden_states, temb):
+        residual = hidden_states
+        # flatten (batch, channels, height, width) -> (batch, channels, height * width)
+        hidden_states = self.norm(hidden_states, temb).view(hidden_states.shape[0], hidden_states.shape[1], -1)
+        # Attention expects (batch, sequence, channels)
+        hidden_states = self.attention(hidden_states.transpose(1, 2), None, None).transpose(1, 2)
+        hidden_states = hidden_states.view(residual.shape)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
 class UpDecoderBlock2D(nn.Module):
     def __init__(
         self,
@@ -1985,6 +2010,7 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        temb_channels=None,
     ):
         super().__init__()
         resnets = []
@@ -1996,7 +2022,7 @@ def __init__(
             ResnetBlock2D(
                 in_channels=input_channels,
                 out_channels=out_channels,
-                temb_channels=None,
+                temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
                 dropout=dropout,
@@ -2014,9 +2040,9 @@ def __init__(
         else:
             self.upsamplers = None

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, temb=None):
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = resnet(hidden_states, temb=temb)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2040,6 +2066,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
+        temb_channels=None,
     ):
         super().__init__()
         resnets = []
@@ -2052,7 +2079,7 @@ def __init__(
             ResnetBlock2D(
                 in_channels=input_channels,
                 out_channels=out_channels,
-                temb_channels=None,
+                temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
                 dropout=dropout,
@@ -2075,7 +2102,6 @@ def __init__(
                     upcast_softmax=True,
                     _from_deprecated_attn_block=True,
                 )
-            )

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
@@ -2085,10 +2111,10 @@ def __init__(
         else:
             self.upsamplers = None

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, temb=None):
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb=None)
-            hidden_states = attn(hidden_states)
+            hidden_states = resnet(hidden_states, temb=temb)
+            hidden_states = attn(hidden_states, temb)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 400c3030af90..776203042e9b 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -20,7 +20,7 @@
 from ..utils import BaseOutput, randn_tensor
+from .attention import SpatialNorm
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


 @dataclass
 class DecoderOutput(BaseOutput):
@@ -149,6 +149,7 @@ def __init__(
         layers_per_block=2,
         norm_num_groups=32,
         act_fn="silu",
+        norm_type="default",  # default, spatial
     ):
         super().__init__()
         self.layers_per_block = layers_per_block
@@ -164,16 +165,19 @@ def __init__(
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+
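+        # In "spatial" mode, the quantized latents are passed to every resnet and
+        # attention block via the `temb` argument, so `temb_channels` must match the
+        # latent channel count (the decoder's `in_channels`).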
+        temb_channels = in_channels if norm_type == "spatial" else None
+
         # mid
         self.mid_block = UNetMidBlock2D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
             output_scale_factor=1,
-            resnet_time_scale_shift="default",
+            resnet_time_scale_shift=norm_type,
             attn_num_head_channels=None,
             resnet_groups=norm_num_groups,
-            temb_channels=None,
+            temb_channels=temb_channels,
         )

         # up
@@ -196,19 +200,23 @@ def __init__(
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
-                temb_channels=None,
+                temb_channels=temb_channels,
+                resnet_time_scale_shift=norm_type,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+        if norm_type == "spatial":
+            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+        else:
+            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

         self.gradient_checkpointing = False

-    def forward(self, z):
+    def forward(self, z, zq=None):
         sample = z
         sample = self.conv_in(sample)

@@ -230,15 +238,18 @@ def custom_forward(*inputs):
                 sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
         else:
             # middle
-            sample = self.mid_block(sample)
+            sample = self.mid_block(sample, zq)
             sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
-                sample = up_block(sample)
+                sample = up_block(sample, zq)

         # post-process
-        sample = self.conv_norm_out(sample)
+        if zq is None:
+            sample = self.conv_norm_out(sample)
+        else:
+            sample = self.conv_norm_out(sample, zq)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index 65f734dccb2d..040447ba82c8 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -82,9 +82,11 @@ def __init__(
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
         scaling_factor: float = 0.18215,
+        norm_type: str = "default",
     ):
         super().__init__()

         # pass init params to Encoder
         self.encoder = Encoder(
             in_channels=in_channels,
@@ -112,6 +114,7 @@ def __init__(
             layers_per_block=layers_per_block,
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
         )

     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -131,8 +134,8 @@ def decode(
             quant, emb_loss, info = self.quantize(h)
         else:
             quant = h
-        quant = self.post_quant_conv(quant)
-        dec = self.decoder(quant)
+        quant2 = self.post_quant_conv(quant)
+        # In "spatial" mode the decoder also receives the pre-conv quantized latents for SpatialNorm.
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

         if not return_dict:
             return (dec,)
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index d988f38506ea..b0d8b4b429a1 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -19,7 +19,7 @@
     XLMRobertaTokenizerFast,
 )

+from PIL import Image
+
-from ...models import UNet2DConditionModel
+from ...models import UNet2DConditionModel, VQModel
 from ...pipelines import DiffusionPipeline
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
@@ -30,6 +30,7 @@
 )
 from .text_encoder import MultilingualCLIP
 from .text_proj import KandinskyTextProjModel


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -44,6 +45,21 @@ def get_new_h_w(h, w):
         new_w += 1
     return new_h * 8, new_w * 8

+
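+# Convert a batch of decoder outputs in [-1, 1] into a list of uint8 PIL images.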
+def process_images(batch):
+    scaled = (
+        ((batch + 1) * 127.5)
+        .round()
+        .clamp(0, 255)
+        .to(torch.uint8)
+        .to("cpu")
+        .permute(0, 2, 3, 1)
+        .numpy()
+    )
+    images = []
+    for i in range(scaled.shape[0]):
+        images.append(Image.fromarray(scaled[i]))
+    return images
+

 class KandinskyPipeline(DiffusionPipeline):
     """
@@ -63,6 +79,8 @@ class KandinskyPipeline(DiffusionPipeline):
             Conditional U-Net architecture to denoise the image embedding.
         text_proj ([`KandinskyTextProjModel`]):
             Utility class to prepare and combine the embeddings before they are passed to the decoder.
+        decoder ([`VQModel`]):
+            MoVQ decoder to generate the image from the latents.
     """

     def __init__(
@@ -72,6 +90,7 @@ def __init__(
         text_proj: KandinskyTextProjModel,
         unet: UNet2DConditionModel,
         scheduler: UnCLIPScheduler,
+        decoder: VQModel,
     ):
         super().__init__()

@@ -81,6 +100,7 @@ def __init__(
             text_proj=text_proj,
             unet=unet,
             scheduler=scheduler,
+            decoder=decoder,
         )

     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
@@ -94,6 +114,13 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
             latents = latents * scheduler.init_noise_sigma
         return latents

+    def get_image(self, latents):
+        # force_not_quantize=True skips the vector-quantization lookup and decodes the
+        # latents as-is.
+        images = self.decoder.decode(latents, force_not_quantize=True)["sample"]
+        images = process_images(images)
+        return images
+
     def _encode_prompt(
         self,
         prompt,
@@ -371,4 +398,5 @@ def __call__(

         _, latents = latents.chunk(2)

-        return latents
+        images = self.get_image(latents)
+        return images