Merged
420 changes: 415 additions & 5 deletions scripts/convert_kandinsky_to_diffusers.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions src/diffusers/models/attention.py
@@ -369,3 +369,24 @@ def forward(self, x, emb):
        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x

class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        # resize the quantized latent zq to the spatial size of the feature map f
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        # group-normalize f, then scale and shift it per pixel with 1x1 convolutions of zq
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
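A minimal usage sketch of the layer above (shapes and channel counts invented for illustration, assuming SpatialNorm is in scope): the quantized latent zq is nearest-neighbor upsampled to the feature map's resolution, then used to predict a per-pixel scale and shift.

import torch

norm = SpatialNorm(f_channels=512, zq_channels=4)
f = torch.randn(1, 512, 32, 32)  # decoder feature map
zq = torch.randn(1, 4, 8, 8)     # quantized latent at lower resolution
out = norm(f, zq)                # zq is interpolated to 32x32 internally
assert out.shape == f.shape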
14 changes: 9 additions & 5 deletions src/diffusers/models/resnet.py
@@ -20,7 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .attention import AdaGroupNorm
from .attention import AdaGroupNorm, SpatialNorm


class Upsample1D(nn.Module):
@@ -460,7 +460,7 @@ def __init__(
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
@@ -487,6 +487,8 @@ def __init__(

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

@@ -497,7 +499,7 @@ def __init__(
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group":
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
@@ -506,6 +508,8 @@ def __init__(

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

@@ -551,7 +555,7 @@ def __init__(
    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group":
        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)
@@ -579,7 +583,7 @@ def forward(self, input_tensor, temb):
        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group":
        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)
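With these changes a ResnetBlock2D can be conditioned on a spatial latent by passing it where the time embedding would normally go. A hedged sketch (channel counts invented; keyword names taken from the hunks above, not a verified snippet from the PR):

import torch
from diffusers.models.resnet import ResnetBlock2D

block = ResnetBlock2D(
    in_channels=256,
    out_channels=256,
    temb_channels=4,                # channels of the quantized latent zq
    time_embedding_norm="spatial",  # norm1/norm2 become SpatialNorm, time_emb_proj is None
)
x = torch.randn(1, 256, 32, 32)
zq = torch.randn(1, 4, 8, 8)
out = block(x, zq)  # zq is routed into both SpatialNorm layers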
51 changes: 39 additions & 12 deletions src/diffusers/models/unet_2d_blocks.py
@@ -18,7 +18,7 @@
import torch.nn.functional as F
from torch import nn

from .attention import AdaGroupNorm
from .attention import AdaGroupNorm, AttentionBlock, SpatialNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
@@ -348,6 +348,7 @@ def get_up_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels
        )
    elif up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
@@ -360,6 +361,7 @@ def get_up_block(
            resnet_groups=resnet_groups,
            attn_num_head_channels=attn_num_head_channels,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels
        )
    elif up_block_type == "KUpBlock2D":
        return KUpBlock2D(
@@ -406,7 +408,6 @@ def __init__(
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
@@ -439,7 +440,6 @@ def __init__(
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

@@ -465,7 +465,8 @@ def forward(self, hidden_states, temb=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn(hidden_states)
                hidden_states = attn(hidden_states, temb)

            hidden_states = resnet(hidden_states, temb)

        return hidden_states
@@ -1971,6 +1972,30 @@ def custom_forward(*inputs):
        return hidden_states


class MOVQAttention(nn.Module):
[Collaborator review comment] I think we don't need to create a new class anymore now this PR is merged #3387

    def __init__(self, query_dim, temb_channels, attn_num_head_channels):
        super().__init__()

        self.norm = SpatialNorm(query_dim, temb_channels)
        num_heads = query_dim // attn_num_head_channels if attn_num_head_channels is not None else 1
        dim_head = attn_num_head_channels if attn_num_head_channels is not None else query_dim
        self.attention = Attention(
            query_dim=query_dim,
            heads=num_heads,
            dim_head=dim_head,
            bias=True
        )

    def forward(self, hidden_states, temb):
        residual = hidden_states
        # spatially conditioned norm, then flatten (B, C, H, W) -> (B, C, H*W)
        hidden_states = self.norm(hidden_states, temb).view(hidden_states.shape[0], hidden_states.shape[1], -1)
        # Attention expects (batch, sequence, channels)
        hidden_states = self.attention(hidden_states.transpose(1, 2), None, None).transpose(1, 2)
        hidden_states = hidden_states.view(residual.shape)
        hidden_states = hidden_states + residual
        return hidden_states



class UpDecoderBlock2D(nn.Module):
    def __init__(
        self,
@@ -1985,6 +2010,7 @@ def __init__(
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        temb_channels=None
    ):
        super().__init__()
        resnets = []
@@ -1996,7 +2022,7 @@
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
@@ -2014,9 +2040,9 @@ def __init__(
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
    def forward(self, hidden_states, temb=None):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = resnet(hidden_states, temb=temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
@@ -2040,6 +2066,7 @@ def __init__(
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_upsample=True,
        temb_channels=None
    ):
        super().__init__()
        resnets = []
@@ -2052,7 +2079,7 @@
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
@@ -2075,7 +2102,6 @@ def __init__(
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
@@ -2085,10 +2111,10 @@ def __init__(
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
    def forward(self, hidden_states, temb=None):
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb=temb)
            hidden_states = attn(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
@@ -2847,3 +2873,4 @@ def forward(
        hidden_states = attn_output + hidden_states

        return hidden_states
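To see the new plumbing end to end, a rough sketch of an AttnUpDecoderBlock2D built for spatial conditioning (channel sizes invented; this leans on constructor arguments shown in the hunks above plus hidden context lines, so treat it as an assumption rather than a verified API):

import torch
from diffusers.models.unet_2d_blocks import AttnUpDecoderBlock2D

block = AttnUpDecoderBlock2D(
    in_channels=256,
    out_channels=256,
    temb_channels=4,                    # zq channels
    resnet_time_scale_shift="spatial",  # SpatialNorm inside the resnets
    attn_num_head_channels=64,
    num_layers=2,
)
x = torch.randn(1, 256, 16, 16)
zq = torch.randn(1, 4, 8, 8)
out = block(x, temb=zq)  # temb doubles as zq for the resnets and the attention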

29 changes: 20 additions & 9 deletions src/diffusers/models/vae.py
@@ -20,7 +20,7 @@

from ..utils import BaseOutput, randn_tensor
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block

from .attention import SpatialNorm

@dataclass
class DecoderOutput(BaseOutput):
@@ -149,6 +149,7 @@ def __init__(
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="default",  # default, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
@@ -164,16 +165,19 @@ def __init__(
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            resnet_time_scale_shift=norm_type,
            attn_num_head_channels=None,
            resnet_groups=norm_num_groups,
            temb_channels=None,
            temb_channels=temb_channels,
        )

        # up
@@ -196,19 +200,23 @@ def __init__(
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z):
    def forward(self, z, zq=None):
        sample = z
        sample = self.conv_in(sample)

@@ -230,15 +238,18 @@ def custom_forward(*inputs):
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
        else:
            # middle
            sample = self.mid_block(sample)
            sample = self.mid_block(sample, zq)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample)
                sample = up_block(sample, zq)

        # post-process
        sample = self.conv_norm_out(sample)
        if zq is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, zq)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

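As a rough end-to-end sketch of the decoder path (config values invented, assuming a build that includes this PR): with norm_type="spatial" the decoder derives temb_channels from in_channels and threads zq through the mid block, the up blocks, and conv_norm_out.

import torch
from diffusers.models.vae import Decoder

decoder = Decoder(
    in_channels=4,  # latent channels; reused as temb_channels when norm_type="spatial"
    out_channels=3,
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(64,),
    layers_per_block=2,
    norm_type="spatial",
)
z = torch.randn(1, 4, 32, 32)
sample = decoder(z, zq=z)  # zq conditions every SpatialNorm in the decoder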
7 changes: 5 additions & 2 deletions src/diffusers/models/vq_model.py
@@ -82,9 +82,11 @@ def __init__(
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "default"
[Collaborator review comment] Suggested change:
        norm_type: str = "default"
        norm_type: str = "default"  # "default", "spatial"

    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
@@ -112,6 +114,7 @@ def __init__(
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -131,8 +134,8 @@ def decode(
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        quant2 = self.post_quant_conv(quant)
        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

        if not return_dict:
            return (dec,)
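Putting it all together, a hedged sketch of a spatially conditioned VQModel round trip (config values invented; per the decode hunk above, the pre-projection quantized latent is handed to the decoder as zq when norm_type="spatial"):

import torch
from diffusers import VQModel

model = VQModel(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    norm_type="spatial",
)
image = torch.randn(1, 3, 64, 64)
latents = model.encode(image).latents
reconstruction = model.decode(latents).sample  # decoder conditioned on its own quantized latent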