diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ca72669a9076..462f87d8d850 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,3 +1,7 @@ +import string +from abc import abstractmethod + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -54,6 +58,18 @@ def nonlinearity(x, swish=1.0): return x * F.sigmoid(x * float(swish)) +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + class Upsample(nn.Module): """ An upsampling layer with an optional convolution. @@ -134,154 +150,713 @@ def forward(self, x): return self.op(x) -class UNetUpsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) +# class UNetUpsample(nn.Module): +# def __init__(self, in_channels, with_conv): +# super().__init__() +# self.with_conv = with_conv +# if self.with_conv: +# self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) +# +# def forward(self, x): +# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") +# if self.with_conv: +# x = self.conv(x) +# return x +# +# +# class GlideUpsample(nn.Module): +# """ +# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param +use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If +3D, then # upsampling occurs in the inner-two dimensions. #""" +# +# def __init__(self, channels, use_conv, dims=2, out_channels=None): +# super().__init__() +# self.channels = channels +# self.out_channels = out_channels or channels +# self.use_conv = use_conv +# self.dims = dims +# if use_conv: +# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) +# +# def forward(self, x): +# assert x.shape[1] == self.channels +# if self.dims == 3: +# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") +# else: +# x = F.interpolate(x, scale_factor=2, mode="nearest") +# if self.use_conv: +# x = self.conv(x) +# return x +# +# +# class LDMUpsample(nn.Module): +# """ +# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # +use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If +3D, then # upsampling occurs in the inner-two dimensions. 
#""" +# +# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): +# super().__init__() +# self.channels = channels +# self.out_channels = out_channels or channels +# self.use_conv = use_conv +# self.dims = dims +# if use_conv: +# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) +# +# def forward(self, x): +# assert x.shape[1] == self.channels +# if self.dims == 3: +# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") +# else: +# x = F.interpolate(x, scale_factor=2, mode="nearest") +# if self.use_conv: +# x = self.conv(x) +# return x +# +# +# class GradTTSUpsample(torch.nn.Module): +# def __init__(self, dim): +# super(Upsample, self).__init__() +# self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) +# +# def forward(self, x): +# return self.conv(x) +# +# +# TODO (patil-suraj): needs test +# class Upsample1d(nn.Module): +# def __init__(self, dim): +# super().__init__() +# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) +# +# def forward(self, x): +# return self.conv(x) - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x +# RESNETS -class GlideUpsample(nn.Module): +# unet_glide.py & unet_ldm.py +class ResBlock(TimestepBlock): """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param + use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing + on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for + downsampling. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None): + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): super().__init__() self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels, swish=1.0), + nn.Identity(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, use_conv=False, dims=dims) + self.x_upd = Upsample(channels, use_conv=False, dims=dims) + elif down: + self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") + self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), + nn.SiLU() if use_scale_shift_norm else nn.Identity(), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + +# unet.py +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +# unet_grad_tts.py +class ResnetBlockGradTTS(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super(ResnetBlockGradTTS, self).__init__() + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + if dim != dim_out: + self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x + self.res_conv = torch.nn.Identity() + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +# unet_rl.py +class ResidualTemporalBlock(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): + super().__init__() + + self.blocks = nn.ModuleList( + [ + Conv1dBlock(inp_channels, out_channels, kernel_size), + Conv1dBlock(out_channels, out_channels, kernel_size), + ] + ) + + self.time_mlp = nn.Sequential( + nn.Mish(), + nn.Linear(embed_dim, out_channels), + RearrangeDim(), + # Rearrange("batch t -> batch t 1"), + ) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else 
nn.Identity() + ) + + def forward(self, x, t): + """ + x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x + out_channels x horizon ] + """ + out = self.blocks[0](x) + self.time_mlp(t) + out = self.blocks[1](out) + return out + self.residual_conv(x) + + +# unet_score_estimation.py +class ResnetBlockBigGANpp(nn.Module): + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + up=False, + down=False, + dropout=0.1, + fir=False, + fir_kernel=(1, 3, 3, 1), + skip_rescale=True, + init_scale=0.0, + ): + super().__init__() + + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.up = up + self.down = down + self.fir = fir + self.fir_kernel = fir_kernel + + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch or up or down: + self.Conv_2 = conv1x1(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.in_ch = in_ch + self.out_ch = out_ch + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + + if self.up: + if self.fir: + h = upsample_2d(h, self.fir_kernel, factor=2) + x = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_upsample_2d(h, factor=2) + x = naive_upsample_2d(x, factor=2) + elif self.down: + if self.fir: + h = downsample_2d(h, self.fir_kernel, factor=2) + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_downsample_2d(h, factor=2) + x = naive_downsample_2d(x, factor=2) + + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + + if self.in_ch != self.out_ch or self.up or self.down: + x = self.Conv_2(x) + + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +# unet_score_estimation.py +class ResnetBlockDDPMpp(nn.Module): + """ResBlock adapted from DDPM.""" + + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + conv_shortcut=False, + dropout=0.1, + skip_rescale=False, + init_scale=0.0, + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.out_ch = out_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + if temb is not None: + h += self.Dense_0(self.act(temb))[:, 
:, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if x.shape[1] != self.out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) -class LDMUpsample(nn.Module): +# HELPER Modules + + +def normalization(channels, swish=0.0): """ - An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param - use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. - If 3D, then - upsampling occurs in the inner-two dimensions. + Make a standard normalization layer, with an optional swish activation. + + :param channels: number of input channels. :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, swish, eps=1e-5): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) + self.swish = swish + + def forward(self, x): + y = super().forward(x.float()).to(x.dtype) + if self.swish == 1.0: + y = F.silu(y) + elif self.swish: + y = y * F.sigmoid(y * float(self.swish)) + return y + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +class Block(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super(Block, self).__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + RearrangeDim(), + # Rearrange("batch channels horizon -> batch channels 1 horizon"), + nn.GroupNorm(n_groups, out_channels), + RearrangeDim(), + # Rearrange("batch channels 1 horizon -> batch channels horizon"), + nn.Mish(), + ) def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + return self.block(x) + + +class RearrangeDim(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +def conv1x1(in_planes, out_planes, stride=1, bias=True, 
init_scale=1.0, padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def default_init(scale=1.0): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, "fan_avg", "uniform") + + +def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): + """Ported from JAX.""" + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError("invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + return init -class GradTTSUpsample(torch.nn.Module): - def __init__(self, dim): - super(Upsample, self).__init__() - self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) - def forward(self, x): - return self.conv(x) +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) -# TODO (patil-suraj): needs test -class Upsample1d(nn.Module): - def __init__(self, dim): +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + 
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: + multiple of the upsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def naive_upsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, H * factor, W * factor)) + + +def naive_downsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5)) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): super().__init__() - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) def forward(self, x): - return self.conv(x) - - -# class ResnetBlock(nn.Module): -# def __init__( -# self, -# *, -# in_channels, -# out_channels=None, -# conv_shortcut=False, -# dropout, -# temb_channels=512, -# use_scale_shift_norm=False, -# ): -# super().__init__() -# self.in_channels = in_channels -# out_channels = in_channels if out_channels is None else out_channels -# self.out_channels = out_channels -# self.use_conv_shortcut = conv_shortcut -# self.use_scale_shift_norm = use_scale_shift_norm - -# self.norm1 = Normalize(in_channels) -# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - -# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels -# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles) - -# self.norm2 = Normalize(out_channels) -# self.dropout = torch.nn.Dropout(dropout) -# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) -# if self.in_channels != self.out_channels: -# if self.use_conv_shortcut: -# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) -# else: -# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - -# def forward(self, x, temb): -# h = x -# h = self.norm1(h) -# h = nonlinearity(h) -# h = self.conv1(h) - -# # TODO: check if this broadcasting works correctly for 1D and 3D -# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None] - -# if self.use_scale_shift_norm: -# out_norm, out_rest = self.out_layers[0], self.out_layers[1:] -# scale, shift = torch.chunk(temb, 2, dim=1) -# h = self.norm2(h) * (1 + scale) + shift -# h = out_rest(h) -# else: -# h = h + temb -# h = self.norm2(h) -# h = nonlinearity(h) -# h = self.dropout(h) -# h = self.conv2(h) - -# if self.in_channels != self.out_channels: -# if self.use_conv_shortcut: -# x = self.conv_shortcut(x) -# else: -# x = self.nin_shortcut(x) - -# return x + h + x = x.permute(0, 2, 3, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 3, 1, 2) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[: len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x 
get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +def _einsum(a, b, c, x, y): + einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) + return torch.einsum(einsum_str, x, y) diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 13765e1f8b5d..5bc13f80f958 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -22,7 +22,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding -from .resnet import Downsample, Upsample +from .resnet import Downsample, ResnetBlock, Upsample def nonlinearity(x): @@ -34,46 +34,46 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h +# class ResnetBlock(nn.Module): +# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): +# super().__init__() +# self.in_channels = in_channels +# out_channels = in_channels if out_channels is None else out_channels +# self.out_channels = out_channels +# self.use_conv_shortcut = conv_shortcut +# +# self.norm1 = Normalize(in_channels) +# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) +# self.temb_proj = torch.nn.Linear(temb_channels, out_channels) +# self.norm2 = Normalize(out_channels) +# self.dropout = torch.nn.Dropout(dropout) +# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) +# if self.in_channels != self.out_channels: +# if self.use_conv_shortcut: +# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) +# else: +# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) +# +# def forward(self, x, temb): +# h = x +# h = self.norm1(h) +# h = nonlinearity(h) +# h = self.conv1(h) +# +# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] +# +# h = self.norm2(h) +# h = nonlinearity(h) +# h = self.dropout(h) +# h = self.conv2(h) +# +# if self.in_channels != 
self.out_channels: +# if self.use_conv_shortcut: +# x = self.conv_shortcut(x) +# else: +# x = self.nin_shortcut(x) +# +# return x + h class UNetModel(ModelMixin, ConfigMixin): @@ -127,7 +127,6 @@ def __init__( for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - attn_2 = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): @@ -142,7 +141,6 @@ def __init__( down = nn.Module() down.block = block down.attn = attn - down.attn_2 = attn_2 if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) curr_res = curr_res // 2 diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 53763ddaa08c..c9f06a2d80b4 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -8,7 +8,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding -from .resnet import Downsample, Upsample +from .resnet import Downsample, ResBlock, TimestepBlock, Upsample def convert_module_to_f16(l): @@ -96,16 +96,14 @@ def zero_module(module): return module -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ +# class TimestepBlock(nn.Module): +# """ +# Any module where forward() takes timestep embeddings as a second argument. #""" +# +# @abstractmethod +# def forward(self, x, emb): +# """ +# Apply the module to `x` given `emb` timestep embeddings. #""" class TimestepEmbedSequential(nn.Sequential, TimestepBlock): @@ -124,106 +122,99 @@ def forward(self, x, emb, encoder_out=None): return x -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - - :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param - use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing - on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for - downsampling. 
- """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels, swish=1.0), - nn.Identity(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, use_conv=False, dims=dims) - self.x_upd = Upsample(channels, use_conv=False, dims=dims) - elif down: - self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), - nn.SiLU() if use_scale_shift_norm else nn.Identity(), - nn.Dropout(p=dropout), - zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h +# class ResBlock(TimestepBlock): +# """ +# A residual block that can optionally change the number of channels. # # :param channels: the number of input +channels. :param emb_channels: the number of timestep embedding channels. # :param dropout: the rate of dropout. :param +out_channels: if specified, the number of out channels. :param # use_conv: if True and out_channels is specified, use a +spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param +dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this +module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. 
#""" +# +# def __init__( +# self, +# channels, +# emb_channels, +# dropout, +# out_channels=None, +# use_conv=False, +# use_scale_shift_norm=False, +# dims=2, +# use_checkpoint=False, +# up=False, +# down=False, +# ): +# super().__init__() +# self.channels = channels +# self.emb_channels = emb_channels +# self.dropout = dropout +# self.out_channels = out_channels or channels +# self.use_conv = use_conv +# self.use_checkpoint = use_checkpoint +# self.use_scale_shift_norm = use_scale_shift_norm +# +# self.in_layers = nn.Sequential( +# normalization(channels, swish=1.0), +# nn.Identity(), +# conv_nd(dims, channels, self.out_channels, 3, padding=1), +# ) +# +# self.updown = up or down +# +# if up: +# self.h_upd = Upsample(channels, use_conv=False, dims=dims) +# self.x_upd = Upsample(channels, use_conv=False, dims=dims) +# elif down: +# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") +# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") +# else: +# self.h_upd = self.x_upd = nn.Identity() +# +# self.emb_layers = nn.Sequential( +# nn.SiLU(), +# linear( +# emb_channels, +# 2 * self.out_channels if use_scale_shift_norm else self.out_channels, +# ), +# ) +# self.out_layers = nn.Sequential( +# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), +# nn.SiLU() if use_scale_shift_norm else nn.Identity(), +# nn.Dropout(p=dropout), +# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), +# ) +# +# if self.out_channels == channels: +# self.skip_connection = nn.Identity() +# elif use_conv: +# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) +# else: +# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) +# +# def forward(self, x, emb): +# """ +# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features. +:param emb: an [N x emb_channels] Tensor of timestep embeddings. # :return: an [N x C x ...] Tensor of outputs. 
#""" +# if self.updown: +# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] +# h = in_rest(x) +# h = self.h_upd(h) +# x = self.x_upd(x) +# h = in_conv(h) +# else: +# h = self.in_layers(x) +# emb_out = self.emb_layers(emb).type(h.dtype) +# while len(emb_out.shape) < len(h.shape): +# emb_out = emb_out[..., None] +# if self.use_scale_shift_norm: +# out_norm, out_rest = self.out_layers[0], self.out_layers[1:] +# scale, shift = torch.chunk(emb_out, 2, dim=1) +# h = out_norm(h) * (1 + scale) + shift +# h = out_rest(h) +# else: +# h = h + emb_out +# h = self.out_layers(h) +# return self.skip_connection(x) + h class GlideUNetModel(ModelMixin, ConfigMixin): diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index bc0c1e7a222f..cb6122c897cd 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -4,7 +4,9 @@ from ..modeling_utils import ModelMixin from .attention import LinearAttention from .embeddings import get_timestep_embedding -from .resnet import Downsample, Upsample +from .resnet import Downsample +from .resnet import ResnetBlockGradTTS as ResnetBlock +from .resnet import Upsample class Mish(torch.nn.Module): @@ -34,24 +36,24 @@ def forward(self, x, mask): return output * mask -class ResnetBlock(torch.nn.Module): - def __init__(self, dim, dim_out, time_emb_dim, groups=8): - super(ResnetBlock, self).__init__() - self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) - - self.block1 = Block(dim, dim_out, groups=groups) - self.block2 = Block(dim_out, dim_out, groups=groups) - if dim != dim_out: - self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) - else: - self.res_conv = torch.nn.Identity() - - def forward(self, x, mask, time_emb): - h = self.block1(x, mask) - h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) - h = self.block2(h, mask) - output = h + self.res_conv(x * mask) - return output +# class ResnetBlock(torch.nn.Module): +# def __init__(self, dim, dim_out, time_emb_dim, groups=8): +# super(ResnetBlock, self).__init__() +# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) +# +# self.block1 = Block(dim, dim_out, groups=groups) +# self.block2 = Block(dim_out, dim_out, groups=groups) +# if dim != dim_out: +# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) +# else: +# self.res_conv = torch.nn.Identity() +# +# def forward(self, x, mask, time_emb): +# h = self.block1(x, mask) +# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) +# h = self.block2(h, mask) +# output = h + self.res_conv(x * mask) +# return output class Residual(torch.nn.Module): diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 0012886a5e1e..650b262c9626 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -11,7 +11,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding -from .resnet import Downsample, Upsample +from .resnet import Downsample, ResBlock, TimestepBlock, Upsample def exists(val): @@ -359,16 +359,14 @@ def forward(self, x): return x[:, :, 0] -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ +# class TimestepBlock(nn.Module): +# """ +# Any module where forward() takes timestep embeddings as a second argument. 
#""" +# +# @abstractmethod +# def forward(self, x, emb): +# """ +# Apply the module to `x` given `emb` timestep embeddings. #""" class TimestepEmbedSequential(nn.Sequential, TimestepBlock): @@ -387,99 +385,97 @@ def forward(self, x, emb, context=None): return x -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param - out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use - a spatial - convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing - on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for - downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, use_conv=False, dims=dims) - self.x_upd = Upsample(channels, use_conv=False, dims=dims) - elif down: - self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h +# class A_ResBlock(TimestepBlock): +# """ +# A residual block that can optionally change the number of channels. :param channels: the number of input channels. # +:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. 
:param # +out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use # a +spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param +dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this +module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #""" +# +# def __init__( +# self, +# channels, +# emb_channels, +# dropout, +# out_channels=None, +# use_conv=False, +# use_scale_shift_norm=False, +# dims=2, +# use_checkpoint=False, +# up=False, +# down=False, +# ): +# super().__init__() +# self.channels = channels +# self.emb_channels = emb_channels +# self.dropout = dropout +# self.out_channels = out_channels or channels +# self.use_conv = use_conv +# self.use_checkpoint = use_checkpoint +# self.use_scale_shift_norm = use_scale_shift_norm +# +# self.in_layers = nn.Sequential( +# normalization(channels), +# nn.SiLU(), +# conv_nd(dims, channels, self.out_channels, 3, padding=1), +# ) +# +# self.updown = up or down +# +# if up: +# self.h_upd = Upsample(channels, use_conv=False, dims=dims) +# self.x_upd = Upsample(channels, use_conv=False, dims=dims) +# elif down: +# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") +# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") +# else: +# self.h_upd = self.x_upd = nn.Identity() +# +# self.emb_layers = nn.Sequential( +# nn.SiLU(), +# linear( +# emb_channels, +# 2 * self.out_channels if use_scale_shift_norm else self.out_channels, +# ), +# ) +# self.out_layers = nn.Sequential( +# normalization(self.out_channels), +# nn.SiLU(), +# nn.Dropout(p=dropout), +# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), +# ) +# +# if self.out_channels == channels: +# self.skip_connection = nn.Identity() +# elif use_conv: +# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) +# else: +# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) +# +# def forward(self, x, emb): +# if self.updown: +# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] +# h = in_rest(x) +# h = self.h_upd(h) +# x = self.x_upd(x) +# h = in_conv(h) +# else: +# h = self.in_layers(x) +# emb_out = self.emb_layers(emb).type(h.dtype) +# while len(emb_out.shape) < len(h.shape): +# emb_out = emb_out[..., None] +# if self.use_scale_shift_norm: +# out_norm, out_rest = self.out_layers[0], self.out_layers[1:] +# scale, shift = torch.chunk(emb_out, 2, dim=1) +# h = out_norm(h) * (1 + scale) + shift +# h = out_rest(h) +# else: +# h = h + emb_out +# h = self.out_layers(h) +# return self.skip_connection(x) + h +# class QKVAttention(nn.Module): diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index e41726dbed8c..c4e59b1001df 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -6,6 +6,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding +from .resnet import ResidualTemporalBlock class SinusoidalPosEmb(nn.Module): @@ -72,36 +73,35 @@ def forward(self, x): return self.block(x) -class ResidualTemporalBlock(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): - super().__init__() - - self.blocks = nn.ModuleList( - [ - Conv1dBlock(inp_channels, out_channels, 
kernel_size), - Conv1dBlock(out_channels, out_channels, kernel_size), - ] - ) - - self.time_mlp = nn.Sequential( - nn.Mish(), - nn.Linear(embed_dim, out_channels), - RearrangeDim(), - # Rearrange("batch t -> batch t 1"), - ) - - self.residual_conv = ( - nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() - ) - - def forward(self, x, t): - """ - x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x - out_channels x horizon ] - """ - out = self.blocks[0](x) + self.time_mlp(t) - out = self.blocks[1](out) - return out + self.residual_conv(x) +# class ResidualTemporalBlock(nn.Module): +# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): +# super().__init__() +# +# self.blocks = nn.ModuleList( +# [ +# Conv1dBlock(inp_channels, out_channels, kernel_size), +# Conv1dBlock(out_channels, out_channels, kernel_size), +# ] +# ) +# +# self.time_mlp = nn.Sequential( +# nn.Mish(), +# nn.Linear(embed_dim, out_channels), +# RearrangeDim(), +# Rearrange("batch t -> batch t 1"), +# ) +# +# self.residual_conv = ( +# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() +# ) +# +# def forward(self, x, t): +# """ +# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x # +out_channels x horizon ] #""" +# out = self.blocks[0](x) + self.time_mlp(t) +# out = self.blocks[1](out) +# return out + self.residual_conv(x) class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 17218a7a7ede..c03b513ca6ff 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -28,6 +28,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding +from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): @@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1): return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) -def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): +def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): """1x1 convolution with DDPM initialization.""" conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) @@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad return conv -def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): +def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): """3x3 convolution with DDPM initialization.""" conv = nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias @@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc return conv -conv1x1 = ddpm_conv1x1 -conv3x3 = ddpm_conv3x3 - - def _einsum(a, b, c, x, y): einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) return torch.einsum(einsum_str, x, y) @@ -494,135 +491,135 @@ def forward(self, x): return x -class 
ResnetBlockDDPMpp(nn.Module): - """ResBlock adapted from DDPM.""" - - def __init__( - self, - act, - in_ch, - out_ch=None, - temb_dim=None, - conv_shortcut=False, - dropout=0.1, - skip_rescale=False, - init_scale=0.0, - ): - super().__init__() - out_ch = out_ch if out_ch else in_ch - self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) - self.Conv_0 = conv3x3(in_ch, out_ch) - if temb_dim is not None: - self.Dense_0 = nn.Linear(temb_dim, out_ch) - self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) - nn.init.zeros_(self.Dense_0.bias) - self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) - self.Dropout_0 = nn.Dropout(dropout) - self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) - if in_ch != out_ch: - if conv_shortcut: - self.Conv_2 = conv3x3(in_ch, out_ch) - else: - self.NIN_0 = NIN(in_ch, out_ch) - - self.skip_rescale = skip_rescale - self.act = act - self.out_ch = out_ch - self.conv_shortcut = conv_shortcut - - def forward(self, x, temb=None): - h = self.act(self.GroupNorm_0(x)) - h = self.Conv_0(h) - if temb is not None: - h += self.Dense_0(self.act(temb))[:, :, None, None] - h = self.act(self.GroupNorm_1(h)) - h = self.Dropout_0(h) - h = self.Conv_1(h) - if x.shape[1] != self.out_ch: - if self.conv_shortcut: - x = self.Conv_2(x) - else: - x = self.NIN_0(x) - if not self.skip_rescale: - return x + h - else: - return (x + h) / np.sqrt(2.0) - - -class ResnetBlockBigGANpp(nn.Module): - def __init__( - self, - act, - in_ch, - out_ch=None, - temb_dim=None, - up=False, - down=False, - dropout=0.1, - fir=False, - fir_kernel=(1, 3, 3, 1), - skip_rescale=True, - init_scale=0.0, - ): - super().__init__() - - out_ch = out_ch if out_ch else in_ch - self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) - self.up = up - self.down = down - self.fir = fir - self.fir_kernel = fir_kernel - - self.Conv_0 = conv3x3(in_ch, out_ch) - if temb_dim is not None: - self.Dense_0 = nn.Linear(temb_dim, out_ch) - self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) - nn.init.zeros_(self.Dense_0.bias) - - self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) - self.Dropout_0 = nn.Dropout(dropout) - self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) - if in_ch != out_ch or up or down: - self.Conv_2 = conv1x1(in_ch, out_ch) - - self.skip_rescale = skip_rescale - self.act = act - self.in_ch = in_ch - self.out_ch = out_ch - - def forward(self, x, temb=None): - h = self.act(self.GroupNorm_0(x)) - - if self.up: - if self.fir: - h = upsample_2d(h, self.fir_kernel, factor=2) - x = upsample_2d(x, self.fir_kernel, factor=2) - else: - h = naive_upsample_2d(h, factor=2) - x = naive_upsample_2d(x, factor=2) - elif self.down: - if self.fir: - h = downsample_2d(h, self.fir_kernel, factor=2) - x = downsample_2d(x, self.fir_kernel, factor=2) - else: - h = naive_downsample_2d(h, factor=2) - x = naive_downsample_2d(x, factor=2) - - h = self.Conv_0(h) - # Add bias to each feature map conditioned on the time embedding - if temb is not None: - h += self.Dense_0(self.act(temb))[:, :, None, None] - h = self.act(self.GroupNorm_1(h)) - h = self.Dropout_0(h) - h = self.Conv_1(h) - - if self.in_ch != self.out_ch or self.up or self.down: - x = self.Conv_2(x) - - if not self.skip_rescale: - return x + h - else: - return (x + h) / np.sqrt(2.0) +# class ResnetBlockDDPMpp(nn.Module): +# """ResBlock adapted from DDPM.""" +# +# def 
__init__( +# self, +# act, +# in_ch, +# out_ch=None, +# temb_dim=None, +# conv_shortcut=False, +# dropout=0.1, +# skip_rescale=False, +# init_scale=0.0, +# ): +# super().__init__() +# out_ch = out_ch if out_ch else in_ch +# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) +# self.Conv_0 = conv3x3(in_ch, out_ch) +# if temb_dim is not None: +# self.Dense_0 = nn.Linear(temb_dim, out_ch) +# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) +# nn.init.zeros_(self.Dense_0.bias) +# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) +# self.Dropout_0 = nn.Dropout(dropout) +# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) +# if in_ch != out_ch: +# if conv_shortcut: +# self.Conv_2 = conv3x3(in_ch, out_ch) +# else: +# self.NIN_0 = NIN(in_ch, out_ch) +# +# self.skip_rescale = skip_rescale +# self.act = act +# self.out_ch = out_ch +# self.conv_shortcut = conv_shortcut +# +# def forward(self, x, temb=None): +# h = self.act(self.GroupNorm_0(x)) +# h = self.Conv_0(h) +# if temb is not None: +# h += self.Dense_0(self.act(temb))[:, :, None, None] +# h = self.act(self.GroupNorm_1(h)) +# h = self.Dropout_0(h) +# h = self.Conv_1(h) +# if x.shape[1] != self.out_ch: +# if self.conv_shortcut: +# x = self.Conv_2(x) +# else: +# x = self.NIN_0(x) +# if not self.skip_rescale: +# return x + h +# else: +# return (x + h) / np.sqrt(2.0) + + +# class ResnetBlockBigGANpp(nn.Module): +# def __init__( +# self, +# act, +# in_ch, +# out_ch=None, +# temb_dim=None, +# up=False, +# down=False, +# dropout=0.1, +# fir=False, +# fir_kernel=(1, 3, 3, 1), +# skip_rescale=True, +# init_scale=0.0, +# ): +# super().__init__() +# +# out_ch = out_ch if out_ch else in_ch +# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) +# self.up = up +# self.down = down +# self.fir = fir +# self.fir_kernel = fir_kernel +# +# self.Conv_0 = conv3x3(in_ch, out_ch) +# if temb_dim is not None: +# self.Dense_0 = nn.Linear(temb_dim, out_ch) +# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) +# nn.init.zeros_(self.Dense_0.bias) +# +# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) +# self.Dropout_0 = nn.Dropout(dropout) +# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) +# if in_ch != out_ch or up or down: +# self.Conv_2 = conv1x1(in_ch, out_ch) +# +# self.skip_rescale = skip_rescale +# self.act = act +# self.in_ch = in_ch +# self.out_ch = out_ch +# +# def forward(self, x, temb=None): +# h = self.act(self.GroupNorm_0(x)) +# +# if self.up: +# if self.fir: +# h = upsample_2d(h, self.fir_kernel, factor=2) +# x = upsample_2d(x, self.fir_kernel, factor=2) +# else: +# h = naive_upsample_2d(h, factor=2) +# x = naive_upsample_2d(x, factor=2) +# elif self.down: +# if self.fir: +# h = downsample_2d(h, self.fir_kernel, factor=2) +# x = downsample_2d(x, self.fir_kernel, factor=2) +# else: +# h = naive_downsample_2d(h, factor=2) +# x = naive_downsample_2d(x, factor=2) +# +# h = self.Conv_0(h) +# Add bias to each feature map conditioned on the time embedding +# if temb is not None: +# h += self.Dense_0(self.act(temb))[:, :, None, None] +# h = self.act(self.GroupNorm_1(h)) +# h = self.Dropout_0(h) +# h = self.Conv_1(h) +# +# if self.in_ch != self.out_ch or self.up or self.down: +# x = self.Conv_2(x) +# +# if not self.skip_rescale: +# return x + h +# else: +# return (x + h) / np.sqrt(2.0) class NCSNpp(ModelMixin, ConfigMixin):
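For reference, a minimal sketch of exercising one of the blocks this patch relocates into `src/diffusers/models/resnet.py`. The import path follows the new `from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp` shown above; the channel, embedding, and batch sizes are arbitrary illustration values, and the sketch assumes a diffusers install built from this branch rather than a released version:

```python
import torch
import torch.nn as nn

# ResnetBlockBigGANpp now lives in diffusers.models.resnet per this patch.
from diffusers.models.resnet import ResnetBlockBigGANpp

# Build a downsampling BigGAN++ block: 64 -> 128 channels, conditioned on a
# 256-dim timestep embedding. fir=False means the naive 2x downsample path.
block = ResnetBlockBigGANpp(act=nn.SiLU(), in_ch=64, out_ch=128, temb_dim=256, down=True)

x = torch.randn(2, 64, 32, 32)     # [N, C, H, W] feature map (arbitrary sizes)
temb = torch.randn(2, 256)         # [N, temb_dim] timestep embedding
out = block(x, temb)
print(out.shape)                   # expected: torch.Size([2, 128, 16, 16])
```

With the default `skip_rescale=True`, the residual sum is divided by `sqrt(2)`, matching the score-SDE reference behavior the class was copied from.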
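The FIR resampling helpers (`upsample_2d`, `downsample_2d`, `upfirdn2d`) are also moved into `resnet.py` by this patch. A small sketch of calling them directly, under the same assumption that this branch is installed; with `k=None` they fall back to the `[1] * factor` kernel, i.e. nearest-neighbor upsampling and average-pool downsampling:

```python
import torch

from diffusers.models.resnet import downsample_2d, upsample_2d

x = torch.randn(1, 3, 16, 16)

up = upsample_2d(x, factor=2)        # default kernel -> nearest-neighbor style, [1, 3, 32, 32]
down = downsample_2d(up, factor=2)   # default kernel -> average-pooling style, [1, 3, 16, 16]
print(up.shape, down.shape)
```

Passing an explicit FIR filter such as `k=(1, 3, 3, 1)` reproduces the smoothed resampling used by the BigGAN++ blocks when `fir=True`.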