From 119a451d3ec939891a221e2709cae7e006d1b4d9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 7 Jun 2023 11:32:03 +0200 Subject: [PATCH 001/181] initial --- src/diffusers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02907075345e..499a0115696c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -38,6 +38,7 @@ AutoencoderKL, ControlNetModel, ModelMixin, + PaellaVQModel, PriorTransformer, T5FilmDecoder, Transformer2DModel, From 0623199103acc3a8155d37e1881e0e682c5dad4e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 7 Jun 2023 11:32:26 +0200 Subject: [PATCH 002/181] initial --- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/resnet.py | 80 +++++++++++++++ src/diffusers/models/unet_2d_blocks.py | 11 +++ src/diffusers/models/unet_2d_condition.py | 43 ++++++++ src/diffusers/models/vae.py | 2 +- src/diffusers/models/vq_model.py | 115 +++++++++++++++++++++- src/diffusers/utils/dummy_pt_objects.py | 15 +++ 7 files changed, 265 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 23839c84af45..99b6b8bd676f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -27,7 +27,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel - from .vq_model import VQModel + from .vq_model import PaellaVQModel, VQModel if is_flax_available(): from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 52f01552c528..dd4e876b6859 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -673,6 +673,86 @@ def forward(self, x): return x +class GlobalResponseNorm(nn.Module): + "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, inputs): + Gx = torch.norm(inputs, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (inputs * Nx) + self.beta + inputs + + +class GlobalResponseResidualBlock(nn.Module): + def __init__(self, inp_channels, channel_skip=None, kernel_size=3, dropout=0.0) -> None: + super().__init__() + + # depthwise + self.depthwise = nn.Conv2d( + inp_channels + channel_skip, + inp_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=inp_channels, + ) + self.norm = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + + # channelwise + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, inp_channels * 4), + nn.GELU(), + GlobalResponseNorm(inp_channels * 4), + nn.Dropout(dropout), + nn.Linear(inp_channels * 4, inp_channels), + ) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, inputs, inputs_skip=None): + inputs_res = inputs + if inputs_skip is not None: + inputs = torch.cat([inputs, inputs_skip], dim=1) + inputs = self._norm(self.depthwise(inputs), self.norm).permute(0, 2, 3, 1) + inputs = self.channelwise(inputs).permute(0, 3, 1, 2) + return inputs + inputs_res + + +class MixingResidualBlock(nn.Module): + def __init__(self, inp_channels, c_hidden): + super().__init__() + # depthwise + self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) + ) + + # channelwise + self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, inp_channels), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x + + # unet_rl.py class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 674e58d7180e..2dc5a076da31 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -390,6 +390,17 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") +def get_paella_block(block_type, c_hidden, nhead, c_cond, c_r, kernel_size=3, c_skip=0, dropout=0, self_attn=True): + if block_type == "C": + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r) + else: + raise ValueError(f"'Block type {block_type} not supported.") + + class UNetMidBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dda21fd80479..620916a08560 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -58,6 +58,49 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor +class PaellaUNet2dConditionalModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + c_in=256, + c_out=256, + num_labels=8192, + c_r=64, + patch_size=2, + c_cond=1024, + c_hidden=[640, 1280, 1280], + nhead=[-1, 16, 16], + blocks=[6, 16, 6], + level_config=["CT", "CTA", "CTA"], + clip_embd=1024, + byt5_embd=1536, + clip_seq_len=4, + kernel_size=3, + dropout=0.1, + self_attn=True, + ): + super().__init__() + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + + # CONDITIONING + self.byt5_mapper = nn.Linear(byt5_embd, c_cond) + self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len) + self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len) + self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.in_mapper = nn.Sequential( + nn.Embedding(num_labels, c_in), nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6) + ) + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index dd4af0efcfd9..1c485afc1659 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -51,7 +51,7 @@ def __init__( super().__init__() self.layers_per_block = layers_per_block - self.conv_in = torch.nn.Conv2d( + self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=3, diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 73158294ee6e..42bb6b887ce6 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -20,6 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .modeling_utils import ModelMixin +from .resnet import MixingResidualBlock from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer @@ -36,6 +37,118 @@ class VQEncoderOutput(BaseOutput): latents: torch.FloatTensor +class PaellaVQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from Paella model. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + downscale_factor (int, *optional*, defaults to 2): Downscale factor of the input image. + levels (int, *optional*, defaults to 2): Number of levels in the model. + bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. + c_hidden (int, *optional*, defaults to 384): Number of hidden channels in the model. + c_latent (int, *optional*, defaults to 4): Number of latent channels in the model. + codebook_size (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_down_scale_factor: int = 2, + levels: int = 2, + bottleneck_blocks: int = 12, + c_hidden: int = 384, + c_latent: int = 4, + codebook_size: int = 8192, + scale_factor: float = 0.3764, + ): + super().__init__() + + c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] + self.in_block = nn.Sequential( + nn.PixelUnshuffle(up_down_scale_factor), + nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), + ) + + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + self.vquantizer = VectorQuantizer(codebook_size, vq_embed_dim=c_latent, legacy=False, beta=0.25) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d( + c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 + ) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), + nn.PixelShuffle(up_down_scale_factor), + ) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.in_block(x) + h = self.down_blocks(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if not force_not_quantize: + quant, _, _ = self.quantize(h) + else: + quant = h + x = self.up_blocks(quant) + dec = self.out_block(x) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + class VQModel(ModelMixin, ConfigMixin): r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. @@ -130,7 +243,7 @@ def decode( ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) + quant, _, _ = self.quantize(h) else: quant = h quant2 = self.post_quant_conv(quant) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7a13bc89e883..b31a68bdaea5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -167,6 +167,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class PaellaVQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From 8a6a92c77a92750b2e146c1a4c40b4949db2d100 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 21 Jun 2023 11:45:28 +0200 Subject: [PATCH 003/181] added initial convert script for paella vqmodel --- scripts/convert_paella.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 scripts/convert_paella.py diff --git a/scripts/convert_paella.py b/scripts/convert_paella.py new file mode 100644 index 000000000000..184f3c88e2f0 --- /dev/null +++ b/scripts/convert_paella.py @@ -0,0 +1,33 @@ +import argparse +import inspect +import os + +import numpy as np +import torch +import torch.nn as nn + +from diffusers import PaellaVQModel +from src.vqgan import VQModel +from src.modules import Paella + +model_path = "models/" +device = "cpu" + +paella_vqmodel = VQModel() +state_dict = torch.load(os.path.join(model_path, "vqgan_f4.pt"), map_location=device) +paella_vqmodel.load_state_dict(state_dict) + +state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] +state_dict.pop("vquantizer.codebook.weight") +vqmodel = PaellaVQModel( + codebook_size=paella_vqmodel.codebook_size, + c_latent=paella_vqmodel.c_latent, +) +vqmodel.load_state_dict(state_dict) + +# test vqmodel outputs match paella_vqmodel outputs + + +state_dict = torch.load(os.path.join(model_path, "paella_v3.pt"), map_location=device) +paella_model = Paella(byt5_embd=2560).to(device) +paella_model.load_state_dict(state_dict) From 80713a4de35d0ea6054c83d0090996946d823014 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 22 Jun 2023 10:52:39 +0200 Subject: [PATCH 004/181] initial wuerstchen pipeline --- docs/source/en/api/models.mdx | 3 + docs/source/en/api/pipelines/wuerstchen.mdx | 1 + ...onvert_paella.py => convert_wuerstchen.py} | 12 +-- src/diffusers/__init__.py | 1 + src/diffusers/models/resnet.py | 3 +- src/diffusers/models/vq_model.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/wuerstchen/__init__.py | 5 + .../wuerstchen/pipeline_wuerstchen.py | 92 +++++++++++++++++++ 9 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 docs/source/en/api/pipelines/wuerstchen.mdx rename scripts/{convert_paella.py => convert_wuerstchen.py} (75%) create mode 100644 src/diffusers/pipelines/wuerstchen/__init__.py create mode 100644 src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index 74291f9173ea..956d2ce706c2 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -52,6 +52,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## VQModel [[autodoc]] VQModel +## PaellaVQModel +[[autodoc]] PaellaVQModel + ## AutoencoderKLOutput [[autodoc]] models.autoencoder_kl.AutoencoderKLOutput diff --git a/docs/source/en/api/pipelines/wuerstchen.mdx b/docs/source/en/api/pipelines/wuerstchen.mdx new file mode 100644 index 000000000000..23b5cc5f9f5c --- /dev/null +++ b/docs/source/en/api/pipelines/wuerstchen.mdx @@ -0,0 +1 @@ +# Würstchen \ No newline at end of file diff --git a/scripts/convert_paella.py b/scripts/convert_wuerstchen.py similarity index 75% rename from scripts/convert_paella.py rename to scripts/convert_wuerstchen.py index 184f3c88e2f0..3230cffd594e 100644 --- a/scripts/convert_paella.py +++ b/scripts/convert_wuerstchen.py @@ -1,20 +1,17 @@ -import argparse -import inspect import os -import numpy as np import torch -import torch.nn as nn from diffusers import PaellaVQModel -from src.vqgan import VQModel -from src.modules import Paella +from modules import Paella +from vqgan import VQModel + model_path = "models/" device = "cpu" paella_vqmodel = VQModel() -state_dict = torch.load(os.path.join(model_path, "vqgan_f4.pt"), map_location=device) +state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"] paella_vqmodel.load_state_dict(state_dict) state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] @@ -23,6 +20,7 @@ codebook_size=paella_vqmodel.codebook_size, c_latent=paella_vqmodel.c_latent, ) + vqmodel.load_state_dict(state_dict) # test vqmodel outputs match paella_vqmodel outputs diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 499a0115696c..8c4a03d0c63a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -175,6 +175,7 @@ VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, + WuerstchenPipeline, ) try: diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index dd4e876b6859..e7c1acd9479f 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -710,7 +710,8 @@ def __init__(self, inp_channels, channel_skip=None, kernel_size=3, dropout=0.0) nn.Linear(inp_channels * 4, inp_channels), ) - def _norm(self, x, norm): + @staticmethod + def _norm(x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) def forward(self, inputs, inputs_skip=None): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 42bb6b887ce6..f56acfdf8e94 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -46,7 +46,7 @@ class PaellaVQModel(ModelMixin, ConfigMixin): Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. - downscale_factor (int, *optional*, defaults to 2): Downscale factor of the input image. + up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. levels (int, *optional*, defaults to 2): Number of levels in the model. bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. c_hidden (int, *optional*, defaults to 384): Number of hidden channels in the model. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b1650240848a..a831548eaf1d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -99,6 +99,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline + from .wuerstchen import WuerstchenPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py new file mode 100644 index 000000000000..1570e2f16659 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -0,0 +1,5 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_wuerstchen import WuerstchenPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py new file mode 100644 index 000000000000..cfd03b6fee89 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -0,0 +1,92 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...utils import is_accelerate_available, logging +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPipeline + + >>> pipe = WuerstchenPipeline.from_pretrained("kashif/wuerstchen", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class WuerstchenPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + def __init__(self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, scheduler) -> None: + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + ) + self.register_to_config() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + ): + self.tokenizer.tokenize([prompt] * num_images_per_prompt) + + if negative_prompt: + clip_text_tokens_uncond = self.tokenizer([negative_prompt] * num_images_per_prompt) + self.text_encoder.get_input_embeddings()(clip_text_tokens_uncond["input_ids"]) From 8bd6cb85a390904b2b8e921f121de4b38a569ae9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 22 Jun 2023 12:08:30 +0200 Subject: [PATCH 005/181] add LayerNorm2d --- scripts/convert_wuerstchen.py | 4 ++-- src/diffusers/models/resnet.py | 3 ++- src/diffusers/models/unet_2d_blocks.py | 6 +++--- src/diffusers/models/unet_2d_condition.py | 6 ++++++ 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 3230cffd594e..141c92b6bed5 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -1,11 +1,11 @@ import os import torch - -from diffusers import PaellaVQModel from modules import Paella from vqgan import VQModel +from diffusers import PaellaVQModel + model_path = "models/" device = "cpu" diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index e7c1acd9479f..1863a75957a8 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -742,7 +742,8 @@ def __init__(self, inp_channels, c_hidden): self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - def _norm(self, x, norm): + @staticmethod + def _norm(x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) def forward(self, x): diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2dc5a076da31..6515d2a52ffc 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -22,7 +22,7 @@ from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D, GlobalResponseResidualBlock from .transformer_2d import Transformer2DModel @@ -392,9 +392,9 @@ def get_up_block( def get_paella_block(block_type, c_hidden, nhead, c_cond, c_r, kernel_size=3, c_skip=0, dropout=0, self_attn=True): if block_type == "C": - return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + return GlobalResponseResidualBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "A": - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + return Attention(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) elif block_type == "T": return TimestepBlock(c_hidden, c_r) else: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 620916a08560..fbfa5a1825df 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -57,6 +57,12 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) class PaellaUNet2dConditionalModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): _supports_gradient_checkpointing = True From 806ed12361c7042a23f13f36afd0a1136c7b52f7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 22 Jun 2023 15:00:38 +0200 Subject: [PATCH 006/181] added modules --- scripts/convert_wuerstchen.py | 22 ++++++++++++++----- src/diffusers/models/unet_2d_blocks.py | 13 +++++++++-- src/diffusers/pipelines/wuerstchen/modules.py | 19 ++++++++++++++++ .../wuerstchen/pipeline_wuerstchen.py | 15 ++++++++----- 4 files changed, 57 insertions(+), 12 deletions(-) create mode 100644 src/diffusers/pipelines/wuerstchen/modules.py diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 141c92b6bed5..03b3995b8a19 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -1,11 +1,16 @@ +import argparse +import inspect import os +import numpy as np import torch -from modules import Paella -from vqgan import VQModel +import torch.nn as nn from diffusers import PaellaVQModel +from transformers import CLIPTextModel, AutoTokenizer +from vqgan import VQModel +from modules import Paella, Prior model_path = "models/" device = "cpu" @@ -20,12 +25,19 @@ codebook_size=paella_vqmodel.codebook_size, c_latent=paella_vqmodel.c_latent, ) - vqmodel.load_state_dict(state_dict) +# TODO: test vqmodel outputs match paella_vqmodel outputs -# test vqmodel outputs match paella_vqmodel outputs +# Clip Text encoder and tokenizer +text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +# EfficientNet -state_dict = torch.load(os.path.join(model_path, "paella_v3.pt"), map_location=device) +# Paella +state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)t['state_dict'] paella_model = Paella(byt5_embd=2560).to(device) paella_model.load_state_dict(state_dict) + +# Prior +prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 6515d2a52ffc..8d8d78056adf 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -22,7 +22,16 @@ from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D, GlobalResponseResidualBlock +from .resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + Upsample2D, + GlobalResponseResidualBlock, +) from .transformer_2d import Transformer2DModel @@ -394,7 +403,7 @@ def get_paella_block(block_type, c_hidden, nhead, c_cond, c_r, kernel_size=3, c_ if block_type == "C": return GlobalResponseResidualBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "A": - return Attention(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) elif block_type == "T": return TimestepBlock(c_hidden, c_r) else: diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py new file mode 100644 index 000000000000..06d95a17fe5b --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -0,0 +1,19 @@ +import torch.nn as nn +from torchvision.models import efficientnet_v2_s, efficientnet_v2_l + + +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): + super().__init__() + if effnet == "efficientnet_v2_s": + self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() + else: + print("Using EffNet L.") + self.backbone = efficientnet_v2_l(weights="DEFAULT").features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index cfd03b6fee89..01dee82c6f8c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -14,10 +14,11 @@ from typing import List, Optional, Union import torch -from transformers import T5EncoderModel, T5Tokenizer +from transformers import CLIPTextModel, AutoTokenizer from ...utils import is_accelerate_available, logging from ..pipeline_utils import DiffusionPipeline +from ...models import PaellaVQModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -38,15 +39,19 @@ class WuerstchenPipeline(DiffusionPipeline): - tokenizer: T5Tokenizer - text_encoder: T5EncoderModel + clip_tokenizer: AutoTokenizer + text_encoder: CLIPTextModel + vqmodel: PaellaVQModel - def __init__(self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, scheduler) -> None: + def __init__( + self, clip_tokenizer: AutoTokenizer, text_encoder: CLIPTextModel, vqmodel: PaellaVQModel, scheduler + ) -> None: super().__init__() self.register_modules( - tokenizer=tokenizer, + clip_tokenizer=clip_tokenizer, text_encoder=text_encoder, + vqmodel=vqmodel, scheduler=scheduler, ) self.register_to_config() From ff6139d161307da5565b978fbd211cf3adff77ce Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 22 Jun 2023 15:32:53 +0200 Subject: [PATCH 007/181] fix typo --- scripts/convert_wuerstchen.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 03b3995b8a19..46b6fb6d199e 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -33,11 +33,14 @@ clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # EfficientNet +state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)["effnet_state_dict"] # Paella -state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)t['state_dict'] -paella_model = Paella(byt5_embd=2560).to(device) +state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)["state_dict"] +paella_model = Paella(byt5_embd=1024).to(device) paella_model.load_state_dict(state_dict) # Prior -prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) \ No newline at end of file +state_dict = torch.load(os.path.join(model_path, "model_stage_c.pt"), map_location=device) +prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) +prior_model.load_state_dict(state_dict["state_dict"]) From 560da3b993394f72f1eab2e8c3163650adcb1605 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 10:19:27 +0200 Subject: [PATCH 008/181] use model_v2 --- scripts/convert_wuerstchen.py | 53 +++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 46b6fb6d199e..a8bbbf5872d7 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -33,7 +33,7 @@ clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # EfficientNet -state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)["effnet_state_dict"] +state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["effnet_state_dict"] # Paella state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)["state_dict"] @@ -41,6 +41,55 @@ paella_model.load_state_dict(state_dict) # Prior -state_dict = torch.load(os.path.join(model_path, "model_stage_c.pt"), map_location=device) +state_dict = torch.load(os.path.join(model_path, "model_v2__stage_c.pt"), map_location=device) prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["state_dict"]) + + +# scheduler +scheduler = DDPMScheduler( + beta_schedule="linear", +) + +# WuerstchenPipeline( +# vae=VQGan() +# text_encoder=ClipTextEncoder(), +# prior=prior, +# (image_encoder)=efficient_net, +# ) +# stage C = prior +# stage B = unet +# stage A = vae +# WuerstchenPipeline( +# vae=VQGan() +# text_encoder=ClipTextEncoder(), +# unet = UNet2DConditionModel(), +# prior=prior, +# (image_encoder)=efficient_net, +# ) +# Patrick von Platen4:17 PM +# WuerstchenPipeline( +# vae=VQGan() +# text_encoder=ClipTextEncoder(), +# unet = UNet2DConditionModel(), +# prior=prior, +# tokenizer=CLIPTokenizer, +# (image_encoder)=efficient_net, +# ) +# WuerstchenPipeline( +# vae=VQGan() +# text_encoder=ClipTextEncoder(), +# unet = UNet2DConditionModel(), +# prior=PriorTransformer(), +# tokenizer=CLIPTokenizer, +# (image_encoder)=efficient_net, +# ) +# Patrick von Platen4:20 PM +# WuerstchenPipeline( +# vae=VQGan() +# text_encoder=ClipTextEncoder(), +# unet = NewUNet(), # Paella Style +# prior=NewPrior(), # find good name +# tokenizer=CLIPTokenizer, +# (image_encoder)=efficient_net, +# ) From 3acc9fa691caeedc3a770d8dc9f5d125f173db72 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 10:26:48 +0200 Subject: [PATCH 009/181] embed clip caption amd negative_caption --- .../wuerstchen/pipeline_wuerstchen.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 01dee82c6f8c..e61ca33842c9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -90,8 +90,23 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, ): - self.tokenizer.tokenize([prompt] * num_images_per_prompt) + clip_tokens = self.tokenizer( + [prompt] * num_images_per_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + clip_text_embeddings = self.text_encoder(**clip_tokens).last_hidden_state + + if negative_prompt is None: + negative_prompt = "" - if negative_prompt: - clip_text_tokens_uncond = self.tokenizer([negative_prompt] * num_images_per_prompt) - self.text_encoder.get_input_embeddings()(clip_text_tokens_uncond["input_ids"]) + clip_text_tokens_uncond = self.tokenizer( + [negative_prompt] * num_images_per_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + clip_text_embeddings_uncond = self.text_encoder(**clip_tokens_uncond).last_hidden_state From f84ac09e984a42bf6867c5e87f34450bcd0685f1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 10:27:44 +0200 Subject: [PATCH 010/181] fixed name of var --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index e61ca33842c9..bc52a8c1ecfa 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -102,7 +102,7 @@ def __call__( if negative_prompt is None: negative_prompt = "" - clip_text_tokens_uncond = self.tokenizer( + clip_tokens_uncond = self.tokenizer( [negative_prompt] * num_images_per_prompt, truncation=True, padding="max_length", From 25de2c68c6f5d3cf3726fb2424f9cc388f966513 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 11:21:18 +0200 Subject: [PATCH 011/181] initial modules in one place --- src/diffusers/pipelines/wuerstchen/modules.py | 385 +++++++++++++++++- .../wuerstchen/pipeline_wuerstchen.py | 8 +- 2 files changed, 388 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 06d95a17fe5b..bfd358f3d4df 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -1,5 +1,109 @@ +import math + +import numpy as np +import torch import torch.nn as nn -from torchvision.models import efficientnet_v2_s, efficientnet_v2_l +from torchvision.models import efficientnet_v2_l, efficientnet_v2_s + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x class EfficientNetEncoder(nn.Module): @@ -17,3 +121,282 @@ def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): def forward(self, x): return self.mapper(self.backbone(x)) + + +class Prior(nn.Module): + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): + super().__init__() + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + LayerNorm2d(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + self.apply(self._init_weights) # General init + nn.init.normal_(self.projection.weight, std=0.02) # inputs + nn.init.normal_(self.cond_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.cond_mapper[-1].weight, std=0.02) # conditionings + nn.init.constant_(self.out[1].weight, 0) # outputs + + # blocks + for block in self.blocks: + if isinstance(block, ResBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / depth) + elif isinstance(block, TimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + # denoised = a / (1-(1-b).pow(2)).sqrt() + return (x_in - a) / ((1 - b).abs() + 1e-5) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data * (1 - beta) + + +class DiffNeXt(nn.Module): + def __init__( + self, + c_in=4, + c_out=4, + c_r=64, + patch_size=2, + c_cond=1024, + c_hidden=[320, 640, 1280, 1280], + nhead=[-1, 10, 20, 20], + blocks=[4, 4, 14, 4], + level_config=["CT", "CTA", "CTA", "CTA"], + inject_effnet=[False, True, True, True], + effnet_embd=16, + clip_embd=1024, + kernel_size=3, + dropout=0.1, + self_attn=True, + ): + super().__init__() + self.c_r = c_r + self.c_cond = c_cond + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + + # CONDITIONING + self.clip_mapper = nn.Linear(clip_embd, c_cond) + self.effnet_mappers = nn.ModuleList( + [ + nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None + for inject in inject_effnet + list(reversed(inject_effnet)) + ] + ) + self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): + if block_type == "C": + return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r) + else: + raise Exception(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + for i in range(len(c_hidden)): + down_block = nn.ModuleList() + if i > 0: + down_block.append( + nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + ) + ) + for _ in range(blocks[i]): + for block_type in level_config[i]: + c_skip = c_cond if inject_effnet[i] else 0 + down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + self.down_blocks.append(down_block) + + # -- up blocks + self.up_blocks = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + up_block = nn.ModuleList() + for j in range(blocks[i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + c_skip += c_cond if inject_effnet[i] else 0 + up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + if i > 0: + up_block.append( + nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + ) + ) + self.up_blocks.append(up_block) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + for mapper in self.effnet_mappers: + if mapper is not None: + nn.init.normal_(mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlockStageB): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks)) + elif isinstance(block, TimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb + + def gen_c_embeddings(self, clip): + clip = self.clip_mapper(clip) + clip = self.seq_norm(clip) + return clip + + def _down_encode(self, x, r_embed, effnet, clip): + level_outputs = [] + for i, down_block in enumerate(self.down_blocks): + effnet_c = None + for block in down_block: + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[i] is not None: + effnet_c = self.effnet_mappers[i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ) + ) + skip = effnet_c if self.effnet_mappers[i] is not None else None + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, effnet, clip): + x = level_outputs[0] + for i, up_block in enumerate(self.up_blocks): + effnet_c = None + for j, block in enumerate(up_block): + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None: + effnet_c = self.effnet_mappers[len(self.down_blocks) + i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ) + ) + skip = level_outputs[i] if j == 0 and i > 0 else None + if effnet_c is not None: + if skip is not None: + skip = torch.cat([skip, effnet_c], dim=1) + else: + skip = effnet_c + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + return x + + def forward(self, x, r, effnet, clip, x_cat=None, eps=1e-3, return_noise=True): + if x_cat is not None: + x = torch.cat([x, x_cat], dim=1) + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x_in = x + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, effnet, clip) + x = self._up_decode(level_outputs, r_embed, effnet, clip) + a, b = self.clf(x).chunk(2, dim=1) + b = b.sigmoid() * (1 - eps * 2) + eps + if return_noise: + return (x_in - a) / b + else: + return a, b + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data * (1 - beta) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index bc52a8c1ecfa..a86d8bc40753 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -14,11 +14,11 @@ from typing import List, Optional, Union import torch -from transformers import CLIPTextModel, AutoTokenizer +from transformers import AutoTokenizer, CLIPTextModel +from ...models import PaellaVQModel from ...utils import is_accelerate_available, logging from ..pipeline_utils import DiffusionPipeline -from ...models import PaellaVQModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -97,7 +97,7 @@ def __call__( max_length=self.tokenizer.model_max_length, return_tensors="pt", ) - clip_text_embeddings = self.text_encoder(**clip_tokens).last_hidden_state + self.text_encoder(**clip_tokens).last_hidden_state if negative_prompt is None: negative_prompt = "" @@ -109,4 +109,4 @@ def __call__( max_length=self.tokenizer.model_max_length, return_tensors="pt", ) - clip_text_embeddings_uncond = self.text_encoder(**clip_tokens_uncond).last_hidden_state + self.text_encoder(**clip_tokens_uncond).last_hidden_state From 30e41a5d54c36bc5e549375f9ff642ec3eb8b71f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 13:26:56 +0200 Subject: [PATCH 012/181] WuerstchenPriorPipeline --- .../wuerstchen/pipeline_wuerstchen.py | 87 ++++++++++++------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index a86d8bc40753..9a4df3ecbd61 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from dataclasses import dataclass from typing import List, Optional, Union import torch -from transformers import AutoTokenizer, CLIPTextModel +from transformers import CLIPTokenizer, CLIPTextModel from ...models import PaellaVQModel -from ...utils import is_accelerate_available, logging +from ...utils import is_accelerate_available, logging, BaseOutput from ..pipeline_utils import DiffusionPipeline +from ...schedulers import DDPMScheduler +from .modules import DiffNeXt, Prior, EfficientNetEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -39,47 +43,62 @@ class WuerstchenPipeline(DiffusionPipeline): - clip_tokenizer: AutoTokenizer - text_encoder: CLIPTextModel + unet: DiffNeXt vqmodel: PaellaVQModel + +@dataclass +class WuerstchenPriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeds (`torch.FloatTensor`) + clip image embeddings for text prompt + negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) + clip image embeddings for unconditional tokens + """ + + image_embeds: Union[torch.FloatTensor, np.ndarray] + negative_image_embeds: Union[torch.FloatTensor, np.ndarray] + + +class WuerstchenPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + def __init__( - self, clip_tokenizer: AutoTokenizer, text_encoder: CLIPTextModel, vqmodel: PaellaVQModel, scheduler + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prior: Prior, + scheduler: DDPMScheduler, ) -> None: super().__init__() self.register_modules( - clip_tokenizer=clip_tokenizer, + tokenizer=tokenizer, text_encoder=text_encoder, - vqmodel=vqmodel, + prior=prior, scheduler=scheduler, ) self.register_to_config() - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - models = [ - self.text_encoder, - self.unet, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - @torch.no_grad() def __call__( self, @@ -89,6 +108,12 @@ def __call__( guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, ): clip_tokens = self.tokenizer( [prompt] * num_images_per_prompt, From d56321898412797e2b59de13630aa6085929b695 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 14:22:53 +0200 Subject: [PATCH 013/181] inital shape --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9a4df3ecbd61..f7f7271c3f63 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from typing import List, Optional, Union +import numpy as np import torch from transformers import CLIPTokenizer, CLIPTextModel @@ -108,13 +109,13 @@ def __call__( guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, - num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - guidance_scale: float = 4.0, output_type: Optional[str] = "pt", # pt only return_dict: bool = True, ): + do_classifier_free_guidance = guidance_scale > 1.0 + clip_tokens = self.tokenizer( [prompt] * num_images_per_prompt, truncation=True, @@ -122,7 +123,7 @@ def __call__( max_length=self.tokenizer.model_max_length, return_tensors="pt", ) - self.text_encoder(**clip_tokens).last_hidden_state + clip_text_embeddings = self.text_encoder(**clip_tokens).last_hidden_state if negative_prompt is None: negative_prompt = "" @@ -134,4 +135,6 @@ def __call__( max_length=self.tokenizer.model_max_length, return_tensors="pt", ) - self.text_encoder(**clip_tokens_uncond).last_hidden_state + clip_text_embeddings_uncond = self.text_encoder(**clip_tokens_uncond).last_hidden_state + + effnet_features_shape = (num_images_per_prompt, 16, 24, 24) From d328459fe4e003ef62d486548e43d03a762b20bf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 15:08:10 +0200 Subject: [PATCH 014/181] initial denoising prior loop --- .../wuerstchen/pipeline_wuerstchen.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index f7f7271c3f63..4d35ce6a8fda 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -20,7 +20,7 @@ from transformers import CLIPTokenizer, CLIPTextModel from ...models import PaellaVQModel -from ...utils import is_accelerate_available, logging, BaseOutput +from ...utils import is_accelerate_available, logging, BaseOutput, randn_tensor from ..pipeline_utils import DiffusionPipeline from ...schedulers import DDPMScheduler @@ -100,6 +100,17 @@ def __init__( ) self.register_to_config() + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + @torch.no_grad() def __call__( self, @@ -138,3 +149,41 @@ def __call__( clip_text_embeddings_uncond = self.text_encoder(**clip_tokens_uncond).last_hidden_state effnet_features_shape = (num_images_per_prompt, 16, 24, 24) + + device = "cuda" + + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + latents = self.prepare_latents( + effnet_features_shape, + clip_text_embeddings.dtype, + device, + generator, + latents, + self.scheduler, + ) + + cond = torch.cat([clip_text_embeddings, clip_text_embeddings_uncond]) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # x, r, c + predicted_image_embedding = self.prior(latents, r=t, c=cond) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + if not return_dict: + return (latents, clip_text_embeddings, clip_text_embeddings_uncond) + + return WuerstchenPriorPipelineOutput(latents, clip_text_embeddings, clip_text_embeddings_uncond) From f0cc379ff26cb469d3db8106a2ab417fbaca7b79 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 17:11:58 +0200 Subject: [PATCH 015/181] fix output --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 4d35ce6a8fda..9b3107254765 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -61,7 +61,8 @@ class WuerstchenPriorPipelineOutput(BaseOutput): """ image_embeds: Union[torch.FloatTensor, np.ndarray] - negative_image_embeds: Union[torch.FloatTensor, np.ndarray] + text_embeds: Union[torch.FloatTensor, np.ndarray] + negative_text_embeds: Union[torch.FloatTensor, np.ndarray] class WuerstchenPriorPipeline(DiffusionPipeline): @@ -183,6 +184,9 @@ def __call__( prev_timestep=prev_timestep, ).prev_sample + # normalize the latents + latent = latent * 42.0 - 1.0 + if not return_dict: return (latents, clip_text_embeddings, clip_text_embeddings_uncond) From 4c8a7918334bde96d9dfa67d52e6e6482284f78f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 17:13:39 +0200 Subject: [PATCH 016/181] add WuerstchenPriorPipeline to __init__.py --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8c4a03d0c63a..26874f53d225 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -176,6 +176,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, WuerstchenPipeline, + WuerstchenPriorPipeline, ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a831548eaf1d..1dee303b5437 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -99,7 +99,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenPipeline + from .wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 1570e2f16659..1bbbcfa23826 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,4 +2,4 @@ if is_transformers_available() and is_torch_available(): - from .pipeline_wuerstchen import WuerstchenPipeline + from .pipeline_wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline From ad474b15bccc386cdd909e4609a65a346cd9b5fb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 18:06:07 +0200 Subject: [PATCH 017/181] use the noise ratio in the Prior --- scripts/convert_wuerstchen.py | 21 +++++++++++-------- .../wuerstchen/pipeline_wuerstchen.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index a8bbbf5872d7..82fd7b4b853b 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from diffusers import PaellaVQModel +from diffusers import PaellaVQModel, WuerstchenPipeline, WuerstchenPriorPipeline, DDPMScheduler from transformers import CLIPTextModel, AutoTokenizer from vqgan import VQModel @@ -30,20 +30,15 @@ # Clip Text encoder and tokenizer text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") -clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # EfficientNet state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["effnet_state_dict"] -# Paella -state_dict = torch.load(os.path.join(model_path, "model_stage_b.pt"), map_location=device)["state_dict"] -paella_model = Paella(byt5_embd=1024).to(device) -paella_model.load_state_dict(state_dict) - # Prior -state_dict = torch.load(os.path.join(model_path, "model_v2__stage_c.pt"), map_location=device) +state_dict = torch.load(os.path.join(model_path, "model_v2_stage_c.pt"), map_location=device) prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) -prior_model.load_state_dict(state_dict["state_dict"]) +prior_model.load_state_dict(state_dict["ema_state_dict"]) # scheduler @@ -51,6 +46,14 @@ beta_schedule="linear", ) +# Prior pipeline +prior_pipeline = WuerstchenPriorPipeline( + prior=prior_model, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, +) + # WuerstchenPipeline( # vae=VQGan() # text_encoder=ClipTextEncoder(), diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9b3107254765..83caf18355ca 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -169,7 +169,7 @@ def __call__( for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # x, r, c - predicted_image_embedding = self.prior(latents, r=t, c=cond) + predicted_image_embedding = self.prior(latents, r=t / num_inference_steps, c=cond) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None From a79a9ad371d95d74c51e8209d09bcd88863b5e8d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 18:23:59 +0200 Subject: [PATCH 018/181] try to save pipeline --- scripts/convert_wuerstchen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 82fd7b4b853b..cedc6b7539ac 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -54,6 +54,8 @@ scheduler=scheduler, ) +prior_pipeline.save_pretrained("kashif/WuerstchenPriorPipeline") + # WuerstchenPipeline( # vae=VQGan() # text_encoder=ClipTextEncoder(), From 4c28f9cd633154060f246566c494c1bf19bd668e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Jun 2023 18:43:02 +0200 Subject: [PATCH 019/181] save_pretrained working --- scripts/convert_wuerstchen.py | 2 +- src/diffusers/pipelines/wuerstchen/__init__.py | 1 + src/diffusers/pipelines/wuerstchen/modules.py | 6 +++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index cedc6b7539ac..13264328a173 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -7,10 +7,10 @@ import torch.nn as nn from diffusers import PaellaVQModel, WuerstchenPipeline, WuerstchenPriorPipeline, DDPMScheduler +from diffusers.pipelines.wuerstchen import Prior from transformers import CLIPTextModel, AutoTokenizer from vqgan import VQModel -from modules import Paella, Prior model_path = "models/" device = "cpu" diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 1bbbcfa23826..2da990cd5ea4 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -3,3 +3,4 @@ if is_transformers_available() and is_torch_available(): from .pipeline_wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline + from .modules import Prior diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index bfd358f3d4df..d45fe0d94314 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -5,6 +5,9 @@ import torch.nn as nn from torchvision.models import efficientnet_v2_l, efficientnet_v2_s +from diffusers.configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + class LayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): @@ -123,7 +126,8 @@ def forward(self, x): return self.mapper(self.backbone(x)) -class Prior(nn.Module): +class Prior(ModelMixin, ConfigMixin): + @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): super().__init__() self.c_r = c_r From 6e51d7e89ad31d940aab91712c40bb0dc2703470 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Sat, 24 Jun 2023 04:37:16 +0200 Subject: [PATCH 020/181] Few additions --- .../wuerstchen/pipeline_wuerstchen.py | 145 ++++++++++++++---- 1 file changed, 118 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 83caf18355ca..bc788e6a350a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -92,7 +92,7 @@ def __init__( scheduler: DDPMScheduler, ) -> None: super().__init__() - + self.multiple = 128 self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, @@ -111,11 +111,95 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + return text_encoder_hidden_states @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, + height: int = 1024, + width: int = 1024, num_inference_steps: int = 100, timesteps: List[int] = None, guidance_scale: float = 7.0, @@ -126,50 +210,51 @@ def __call__( output_type: Optional[str] = "pt", # pt only return_dict: bool = True, ): + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 - clip_tokens = self.tokenizer( - [prompt] * num_images_per_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - clip_text_embeddings = self.text_encoder(**clip_tokens).last_hidden_state - if negative_prompt is None: negative_prompt = "" - clip_tokens_uncond = self.tokenizer( - [negative_prompt] * num_images_per_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - clip_text_embeddings_uncond = self.text_encoder(**clip_tokens_uncond).last_hidden_state + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - effnet_features_shape = (num_images_per_prompt, 16, 24, 24) + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - device = "cuda" + text_encoder_hidden_states = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt) + + latent_height = 128 * (height / 128) // (1024 / 24) + latent_width = 128 * (width / 128) // (1024 / 24) + effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) self.scheduler.set_timesteps(num_inference_steps, device=device) - prior_timesteps_tensor = self.scheduler.timesteps + prior_timesteps_tensor = self.scheduler.timesteps if timesteps is None else timesteps latents = self.prepare_latents( effnet_features_shape, - clip_text_embeddings.dtype, + text_encoder_hidden_states.dtype, device, generator, latents, self.scheduler, ) - cond = torch.cat([clip_text_embeddings, clip_text_embeddings_uncond]) - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # x, r, c - predicted_image_embedding = self.prior(latents, r=t / num_inference_steps, c=cond) + latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + predicted_image_embedding = self.prior(latents, r=t / num_inference_steps, c=text_encoder_hidden_states) + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None @@ -187,7 +272,13 @@ def __call__( # normalize the latents latent = latent * 42.0 - 1.0 + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + latents = latents.cpu().numpy() + if not return_dict: - return (latents, clip_text_embeddings, clip_text_embeddings_uncond) + return (latents, text_encoder_hidden_states) - return WuerstchenPriorPipelineOutput(latents, clip_text_embeddings, clip_text_embeddings_uncond) + return WuerstchenPriorPipelineOutput(latents, text_encoder_hidden_states) From cd5ad0441ce72250f5f6c2bb488de61815f50e74 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Jun 2023 10:52:52 +0200 Subject: [PATCH 021/181] add _execution_device --- .../wuerstchen/pipeline_wuerstchen.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index bc788e6a350a..2e4d4dfbc770 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -111,7 +111,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - + def _encode_prompt( self, prompt, @@ -194,6 +194,24 @@ def _encode_prompt( return text_encoder_hidden_states + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): + return self.device + for module in self.text_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + @torch.no_grad() def __call__( self, @@ -211,7 +229,7 @@ def __call__( return_dict: bool = True, ): device = self._execution_device - + do_classifier_free_guidance = guidance_scale > 1.0 if negative_prompt is None: @@ -227,7 +245,9 @@ def __call__( elif not isinstance(negative_prompt, list) and negative_prompt is not None: raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - text_encoder_hidden_states = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt) + text_encoder_hidden_states = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) latent_height = 128 * (height / 128) // (1024 / 24) latent_width = 128 * (width / 128) // (1024 / 24) From 92c46dfb5c4b4285fd31b683d5432dc14cfef504 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Jun 2023 11:02:12 +0200 Subject: [PATCH 022/181] shape is int --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 2e4d4dfbc770..3bfb21d64ac2 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -249,8 +249,8 @@ def __call__( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - latent_height = 128 * (height / 128) // (1024 / 24) - latent_width = 128 * (width / 128) // (1024 / 24) + latent_height = 128 * (height // 128) // (1024 // 24) + latent_width = 128 * (width // 128) // (1024 // 24) effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -268,7 +268,9 @@ def __call__( for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # x, r, c latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - predicted_image_embedding = self.prior(latents, r=t / num_inference_steps, c=text_encoder_hidden_states) + predicted_image_embedding = self.prior( + latents, r=t / prior_timesteps_tensor.max(), c=text_encoder_hidden_states + ) if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) From 4665e4844e6c9837d168245079bdde417c14669a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Jun 2023 13:42:47 +0200 Subject: [PATCH 023/181] fix batch size --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 3bfb21d64ac2..fd8aa078fdf1 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -269,7 +269,9 @@ def __call__( # x, r, c latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents predicted_image_embedding = self.prior( - latents, r=t / prior_timesteps_tensor.max(), c=text_encoder_hidden_states + latents, + r=(t / prior_timesteps_tensor.max()).expand(num_images_per_prompt * 2, 1), + c=text_encoder_hidden_states, ) if do_classifier_free_guidance: From 58c98b1c30faa4c19d7ab3d6a3a846db008dcb25 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Jun 2023 20:50:47 +0200 Subject: [PATCH 024/181] fix shape of ratio --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index fd8aa078fdf1..f7575759f0b4 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -270,7 +270,7 @@ def __call__( latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents predicted_image_embedding = self.prior( latents, - r=(t / prior_timesteps_tensor.max()).expand(num_images_per_prompt * 2, 1), + r=(t / prior_timesteps_tensor.max()).expand(num_images_per_prompt * 2), c=text_encoder_hidden_states, ) @@ -280,17 +280,17 @@ def __call__( predicted_image_embedding_text - predicted_image_embedding_uncond ) - if i + 1 == prior_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = prior_timesteps_tensor[i + 1] + # if i + 1 == prior_timesteps_tensor.shape[0]: + # prev_timestep = None + # else: + # prev_timestep = prior_timesteps_tensor[i + 1] latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, - prev_timestep=prev_timestep, + # prev_timestep=prev_timestep, ).prev_sample # normalize the latents From 66cff251c25e59f01dec1e610a761825e8885c6e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Jun 2023 11:24:24 +0200 Subject: [PATCH 025/181] fix shape of ratio --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index f7575759f0b4..5971df3c8b68 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -266,11 +266,10 @@ def __call__( ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # x, r, c - latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ratio = t / prior_timesteps_tensor.max() predicted_image_embedding = self.prior( - latents, - r=(t / prior_timesteps_tensor.max()).expand(num_images_per_prompt * 2), + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=ratio.expand(num_images_per_prompt * 2) if do_classifier_free_guidance else ratio, c=text_encoder_hidden_states, ) From d06276d85c5928f3a06514da0c27eb56e0faea06 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Jun 2023 11:32:11 +0200 Subject: [PATCH 026/181] fix output dataclass --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 5971df3c8b68..ae83a0ee15b8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -54,15 +54,14 @@ class WuerstchenPriorPipelineOutput(BaseOutput): Output class for WuerstchenPriorPipeline. Args: - image_embeds (`torch.FloatTensor`) - clip image embeddings for text prompt - negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) - clip image embeddings for unconditional tokens + image_embeds (`torch.FloatTensor` or `np.ndarray`) + Prior image embeddings for text prompt + text_embeds (`torch.FloatTensor` or `np.ndarray`) + Clip text embeddings for unconditional tokens """ image_embeds: Union[torch.FloatTensor, np.ndarray] text_embeds: Union[torch.FloatTensor, np.ndarray] - negative_text_embeds: Union[torch.FloatTensor, np.ndarray] class WuerstchenPriorPipeline(DiffusionPipeline): @@ -293,13 +292,14 @@ def __call__( ).prev_sample # normalize the latents - latent = latent * 42.0 - 1.0 + latents = latents * 42.0 - 1.0 if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": latents = latents.cpu().numpy() + text_encoder_hidden_states = text_encoder_hidden_states.cpu().numpy() if not return_dict: return (latents, text_encoder_hidden_states) From 95eb11edf6090c00c7c9f0b2fc21d9c30b61fe91 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Jun 2023 11:35:23 +0200 Subject: [PATCH 027/181] tests folder --- tests/pipelines/wuerstchen/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/pipelines/wuerstchen/__init__.py diff --git a/tests/pipelines/wuerstchen/__init__.py b/tests/pipelines/wuerstchen/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 896624e7fb804507f804f2242e781fb4d141b5df Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Jun 2023 16:10:13 +0200 Subject: [PATCH 028/181] fix formatting --- src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 ++ src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/modules.py | 1 + src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 8 ++++---- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 7ee38ccf1083..c354b9b107b7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -26,11 +26,11 @@ Downsample2D, FirDownsample2D, FirUpsample2D, + GlobalResponseResidualBlock, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D, - GlobalResponseResidualBlock, ) from .transformer_2d import Transformer2DModel diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b08e3a501436..1f466390b7e8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -57,6 +57,7 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor + class LayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -64,6 +65,7 @@ def __init__(self, *args, **kwargs): def forward(self, x): return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + class PaellaUNet2dConditionalModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): _supports_gradient_checkpointing = True diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 2da990cd5ea4..03c23f48d2c3 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,5 +2,5 @@ if is_transformers_available() and is_torch_available(): - from .pipeline_wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline from .modules import Prior + from .pipeline_wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index d45fe0d94314..ece1fbdd527e 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -6,6 +6,7 @@ from torchvision.models import efficientnet_v2_l, efficientnet_v2_s from diffusers.configuration_utils import ConfigMixin, register_to_config + from ...models.modeling_utils import ModelMixin diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index ae83a0ee15b8..c76f9b3f3fc9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -17,14 +17,14 @@ import numpy as np import torch -from transformers import CLIPTokenizer, CLIPTextModel +from transformers import CLIPTextModel, CLIPTokenizer from ...models import PaellaVQModel -from ...utils import is_accelerate_available, logging, BaseOutput, randn_tensor -from ..pipeline_utils import DiffusionPipeline from ...schedulers import DDPMScheduler +from ...utils import BaseOutput, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .modules import DiffNeXt, Prior -from .modules import DiffNeXt, Prior, EfficientNetEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 1ed7a58607988edee27b1ae45463c44479109475 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Wed, 28 Jun 2023 04:25:51 +0200 Subject: [PATCH 029/181] fix float16 + started with generator --- scripts/vqgan.py | 144 +++++++++++++++ scripts/wuerstchen_pipeline_test.py | 10 ++ src/diffusers/pipelines/wuerstchen/modules.py | 2 +- .../wuerstchen/pipeline_wuerstchen.py | 168 ++++++++++++++++-- 4 files changed, 312 insertions(+), 12 deletions(-) create mode 100644 scripts/vqgan.py create mode 100644 scripts/wuerstchen_pipeline_test.py diff --git a/scripts/vqgan.py b/scripts/vqgan.py new file mode 100644 index 000000000000..935023e1ff6e --- /dev/null +++ b/scripts/vqgan.py @@ -0,0 +1,144 @@ +import torch +from torch import nn +import numpy as np +import math +from tqdm import tqdm +import time +from torchtools.nn import VectorQuantize + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + +class VQModel(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3*4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i]*4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) + up_blocks.append(block) + if i < levels-1: + up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3*4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x): + x = self.in_block(x) + x = self.down_blocks(x) + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def decode_indices(self, x): + x = self.vquantizer.idx2vq(x, dim=1) + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x \ No newline at end of file diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py new file mode 100644 index 000000000000..2fb4ec1da319 --- /dev/null +++ b/scripts/wuerstchen_pipeline_test.py @@ -0,0 +1,10 @@ +import torch +from diffusers import WuerstchenPriorPipeline + +prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) +prior_pipeline = prior_pipeline.to("cuda") + +generator_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) +generator_pipeline = generator_pipeline.to("cuda") + +generator_output = generator_pipeline("An image of a squirrel in Picasso style") \ No newline at end of file diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index ece1fbdd527e..c590d83c3199 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -177,7 +177,7 @@ def gen_r_embedding(self, r, max_positions=10000): emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.c_r % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode="constant") - return emb + return emb.to(dtype=r.dtype) def forward(self, x, r, c): x_in = x diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index c76f9b3f3fc9..10adf8475c89 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -64,6 +64,21 @@ class WuerstchenPriorPipelineOutput(BaseOutput): text_embeds: Union[torch.FloatTensor, np.ndarray] +@dataclass +class WuerstchenGeneratorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeds (`torch.FloatTensor` or `np.ndarray`) + Prior image embeddings for text prompt + text_embeds (`torch.FloatTensor` or `np.ndarray`) + Clip text embeddings for unconditional tokens + """ + + image_embeds: Union[torch.FloatTensor, np.ndarray] + + class WuerstchenPriorPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Wuerstchen. @@ -217,7 +232,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: int = 100, + num_inference_steps: int = 30, timesteps: List[int] = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -247,7 +262,7 @@ def __call__( text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - + dtype = text_encoder_hidden_states.dtype latent_height = 128 * (height // 128) // (1024 // 24) latent_width = 128 * (width // 128) // (1024 // 24) effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) @@ -257,15 +272,18 @@ def __call__( latents = self.prepare_latents( effnet_features_shape, - text_encoder_hidden_states.dtype, + dtype, device, generator, latents, self.scheduler, ) - + print(prior_timesteps_tensor) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - ratio = t / prior_timesteps_tensor.max() + ratio = (t / prior_timesteps_tensor.max()).to(dtype=dtype) + # print(torch.cat([latents] * 2).shape, latents.dtype) + # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) + # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) predicted_image_embedding = self.prior( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=ratio.expand(num_images_per_prompt * 2) if do_classifier_free_guidance else ratio, @@ -278,17 +296,11 @@ def __call__( predicted_image_embedding_text - predicted_image_embedding_uncond ) - # if i + 1 == prior_timesteps_tensor.shape[0]: - # prev_timestep = None - # else: - # prev_timestep = prior_timesteps_tensor[i + 1] - latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, - # prev_timestep=prev_timestep, ).prev_sample # normalize the latents @@ -305,3 +317,137 @@ def __call__( return (latents, text_encoder_hidden_states) return WuerstchenPriorPipelineOutput(latents, text_encoder_hidden_states) + + +class WuerstchenGeneratorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + def __init__( + self, + generator: DiffNeXt, + scheduler: DDPMScheduler, + ) -> None: + super().__init__() + self.multiple = 128 + self.register_modules( + generator=generator, + scheduler=scheduler, + ) + self.register_to_config() + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): + return self.device + for module in self.text_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + predicted_image_embeddings: torch.Tensor, + text_encoder_hidden_states: torch.Tensor, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + check_inputs(text_encoder_hidden_states, predicted_image_embeddings, do_classifier_free_guidance) + + dtype = text_encoder_hidden_states.dtype + latent_height = predicted_image_embeddings.size(2) * (1024 // 24) + latent_width = predicted_image_embeddings.size(2) * (1024 // 24) + effnet_features_shape = (num_images_per_prompt, 4, latent_height, latent_width) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + generator_timesteps_tensor = self.scheduler.timesteps if timesteps is None else timesteps + + latents = self.prepare_latents( + effnet_features_shape, + dtype, + device, + generator, + latents, + self.scheduler, + ) + print(generator_timesteps_tensor) + for i, t in enumerate(self.progress_bar(generator_timesteps_tensor)): + ratio = (t / generator_timesteps_tensor.max()).to(dtype=dtype) + # print(torch.cat([latents] * 2).shape, latents.dtype) + # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) + # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + predicted_image_embedding = self.generator( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=ratio.expand(num_images_per_prompt * 2) if do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, + ) + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + ).prev_sample + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + latents = latents.cpu().numpy() + + if not return_dict: + return (latents, text_encoder_hidden_states) + + return WuerstchenGeneratorPipelineOutput(latents) \ No newline at end of file From e809fd78c4bfc2484b44987e56fb20ca18563eb3 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Thu, 29 Jun 2023 05:00:11 +0200 Subject: [PATCH 030/181] Update pipeline_wuerstchen.py --- .../wuerstchen/pipeline_wuerstchen.py | 67 ++++++++++++------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 10adf8475c89..b701eedc2bb2 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -42,6 +42,10 @@ ``` """ +default_inference_steps = { + 2/3: 20, + 0.0: 10 + } class WuerstchenPipeline(DiffusionPipeline): unet: DiffNeXt @@ -226,6 +230,35 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + @torch.no_grad() + def inference_loop(self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator): + for t in self.progress_bar(steps): + # print(torch.cat([latents] * 2).shape, latents.dtype) + # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) + # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, + c=text_encoder_hidden_states, + ) + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + print(t) + timestep = (t * 999).cpu().int() + print(timestep) + latents = self.scheduler.step( + predicted_image_embedding, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + return latents + @torch.no_grad() def __call__( self, @@ -233,7 +266,7 @@ def __call__( height: int = 1024, width: int = 1024, num_inference_steps: int = 30, - timesteps: List[int] = None, + inference_steps: dict = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -246,6 +279,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 + if inference_steps is None: + inference_steps = default_inference_steps + if negative_prompt is None: negative_prompt = "" @@ -268,7 +304,6 @@ def __call__( effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) self.scheduler.set_timesteps(num_inference_steps, device=device) - prior_timesteps_tensor = self.scheduler.timesteps if timesteps is None else timesteps latents = self.prepare_latents( effnet_features_shape, @@ -278,30 +313,12 @@ def __call__( latents, self.scheduler, ) - print(prior_timesteps_tensor) - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - ratio = (t / prior_timesteps_tensor.max()).to(dtype=dtype) - # print(torch.cat([latents] * 2).shape, latents.dtype) - # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) - # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - predicted_image_embedding = self.prior( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=ratio.expand(num_images_per_prompt * 2) if do_classifier_free_guidance else ratio, - c=text_encoder_hidden_states, - ) - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - latents = self.scheduler.step( - predicted_image_embedding, - timestep=t, - sample=latents, - generator=generator, - ).prev_sample + t_start = 1.0 + for t_end, steps in inference_steps.items(): + steps = torch.linspace(t_start, t_end, steps, dtype=dtype, device=device) + latents = self.inference_loop(latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator) + t_start = t_end # normalize the latents latents = latents * 42.0 - 1.0 From 0ad3f794ad8a46f56b6cca42c8cb9e59b8e38608 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 09:50:23 +0200 Subject: [PATCH 031/181] removed vqgan code --- scripts/vqgan.py | 144 ------------------ .../wuerstchen/pipeline_wuerstchen.py | 24 +-- 2 files changed, 13 insertions(+), 155 deletions(-) delete mode 100644 scripts/vqgan.py diff --git a/scripts/vqgan.py b/scripts/vqgan.py deleted file mode 100644 index 935023e1ff6e..000000000000 --- a/scripts/vqgan.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -from torch import nn -import numpy as np -import math -from tqdm import tqdm -import time -from torchtools.nn import VectorQuantize - -class ResBlock(nn.Module): - def __init__(self, c, c_hidden): - super().__init__() - # depthwise/attention - self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.depthwise = nn.Sequential( - nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) - ) - - # channelwise - self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), - nn.GELU(), - nn.Linear(c_hidden, c), - ) - - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - - def _norm(self, x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - def forward(self, x): - mods = self.gammas - - x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] - - x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - - return x - -class VQModel(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 - super().__init__() - self.c_latent = c_latent - self.scale_factor = scale_factor - c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] - - # Encoder blocks - self.in_block = nn.Sequential( - nn.PixelUnshuffle(2), - nn.Conv2d(3*4, c_levels[0], kernel_size=1) - ) - down_blocks = [] - for i in range(levels): - if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) - block = ResBlock(c_levels[i], c_levels[i]*4) - down_blocks.append(block) - down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - )) - self.down_blocks = nn.Sequential(*down_blocks) - self.down_blocks[0] - - self.codebook_size = codebook_size - self.vquantizer = VectorQuantize(c_latent, k=codebook_size) - - # Decoder blocks - up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) - )] - for i in range(levels): - for j in range(bottleneck_blocks if i == 0 else 1): - block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) - up_blocks.append(block) - if i < levels-1: - up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3*4, kernel_size=1), - nn.PixelShuffle(2), - ) - - def encode(self, x): - x = self.in_block(x) - x = self.down_blocks(x) - qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 - - def decode(self, x): - x = x * self.scale_factor - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def decode_indices(self, x): - x = self.vquantizer.idx2vq(x, dim=1) - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def forward(self, x, quantize=False): - qe, x, _, vq_loss = self.encode(x, quantize) - x = self.decode(qe) - return x, vq_loss - -class Discriminator(nn.Module): - def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): - super().__init__() - d = max(depth - 3, 3) - layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), - nn.LeakyReLU(0.2), - ] - for i in range(depth - 1): - c_in = c_hidden // (2 ** max((d - i), 0)) - c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) - layers.append(nn.InstanceNorm2d(c_out)) - layers.append(nn.LeakyReLU(0.2)) - self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) - self.logits = nn.Sigmoid() - - def forward(self, x, cond=None): - x = self.encoder(x) - if cond is not None: - cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) - x = torch.cat([x, cond], dim=1) - x = self.shuffle(x) - x = self.logits(x) - return x \ No newline at end of file diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index b701eedc2bb2..339725a98eba 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -42,10 +42,8 @@ ``` """ -default_inference_steps = { - 2/3: 20, - 0.0: 10 - } +default_inference_steps = {2 / 3: 20, 0.0: 10} + class WuerstchenPipeline(DiffusionPipeline): unet: DiffNeXt @@ -231,7 +229,9 @@ def _execution_device(self): return self.device @torch.no_grad() - def inference_loop(self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator): + def inference_loop( + self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + ): for t in self.progress_bar(steps): # print(torch.cat([latents] * 2).shape, latents.dtype) # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) @@ -247,12 +247,12 @@ def inference_loop(self, latents, steps, text_encoder_hidden_states, do_classifi predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) - print(t) + # print(t) timestep = (t * 999).cpu().int() - print(timestep) + # print(timestep) latents = self.scheduler.step( predicted_image_embedding, - timestep=timestep, + timestep=timestep - 1, sample=latents, generator=generator, ).prev_sample @@ -317,7 +317,9 @@ def __call__( t_start = 1.0 for t_end, steps in inference_steps.items(): steps = torch.linspace(t_start, t_end, steps, dtype=dtype, device=device) - latents = self.inference_loop(latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator) + latents = self.inference_loop( + latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + ) t_start = t_end # normalize the latents @@ -433,7 +435,7 @@ def __call__( latents, self.scheduler, ) - print(generator_timesteps_tensor) + # print(generator_timesteps_tensor) for i, t in enumerate(self.progress_bar(generator_timesteps_tensor)): ratio = (t / generator_timesteps_tensor.max()).to(dtype=dtype) # print(torch.cat([latents] * 2).shape, latents.dtype) @@ -467,4 +469,4 @@ def __call__( if not return_dict: return (latents, text_encoder_hidden_states) - return WuerstchenGeneratorPipelineOutput(latents) \ No newline at end of file + return WuerstchenGeneratorPipelineOutput(latents) From 2edfc4896f70f91395988701695335eb8caf0f53 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 12:35:12 +0200 Subject: [PATCH 032/181] add WuerstchenGeneratorPipeline --- scripts/convert_wuerstchen.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 13264328a173..4f0107a51b7b 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -1,16 +1,18 @@ -import argparse -import inspect import os -import numpy as np import torch -import torch.nn as nn +from transformers import AutoTokenizer, CLIPTextModel +from vqgan import VQModel +from modules import DiffNeXt, EfficientNetEncoder -from diffusers import PaellaVQModel, WuerstchenPipeline, WuerstchenPriorPipeline, DDPMScheduler +from diffusers import ( + DDPMScheduler, + PaellaVQModel, + WuerstchenPriorPipeline, + WuerstchenGeneratorPipeline, +) from diffusers.pipelines.wuerstchen import Prior -from transformers import CLIPTextModel, AutoTokenizer -from vqgan import VQModel model_path = "models/" device = "cpu" @@ -33,7 +35,13 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # EfficientNet -state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["effnet_state_dict"] +state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) +efficient_net = EfficientNetEncoder() +efficient_net.load_state_dict(state_dict["effnet_state_dict"]) + +# Generator +generator = DiffNeXt() +generator.load_state_dict(state_dict["state_dict"]) # Prior state_dict = torch.load(os.path.join(model_path, "model_v2_stage_c.pt"), map_location=device) @@ -44,6 +52,8 @@ # scheduler scheduler = DDPMScheduler( beta_schedule="linear", + beta_start=0.0001, + beta_end=0.02, ) # Prior pipeline @@ -56,6 +66,14 @@ prior_pipeline.save_pretrained("kashif/WuerstchenPriorPipeline") +generator_pipeline = WuerstchenGeneratorPipeline( + vae=vqmodel, + generator=generator, + efficient_net=efficient_net, + scheduler=scheduler, +) + + # WuerstchenPipeline( # vae=VQGan() # text_encoder=ClipTextEncoder(), From 624c6d9eb611a443c7ebd86bfa4f0c65ef7d634e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 13:19:29 +0200 Subject: [PATCH 033/181] fix WuerstchenGeneratorPipeline --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 339725a98eba..3abbbc5d53c8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -23,7 +23,7 @@ from ...schedulers import DDPMScheduler from ...utils import BaseOutput, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from .modules import DiffNeXt, Prior +from .modules import DiffNeXt, Prior, EfficientNetEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -361,12 +361,16 @@ def __init__( self, generator: DiffNeXt, scheduler: DDPMScheduler, + vqgan: PaellaVQModel, + efficient_net: EfficientNetEncoder, ) -> None: super().__init__() self.multiple = 128 self.register_modules( generator=generator, scheduler=scheduler, + vqgan=vqgan, + efficient_net=efficient_net, ) self.register_to_config() From 0d3c3f39ab93cd5fd977cf3698c4d0d1630f2209 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 13:22:04 +0200 Subject: [PATCH 034/181] fix docstrings --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 3abbbc5d53c8..b32e0297870e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -340,19 +340,18 @@ def __call__( class WuerstchenGeneratorPipeline(DiffusionPipeline): """ - Pipeline for generating image prior for Wuerstchen. + Pipeline for generating images from the Wuerstchen model. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - prior ([`Prior`]): - The canonical unCLIP prior to approximate the image embedding from the text embedding. - text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + generator ([`DiffNeXt`]): + The DiffNeXt unet generator. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + efficient_net ([`EfficientNetEncoder`]): + The EfficientNet encoder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. """ From 38fa6d14bce81a8a75437f10dce7fb6ae4dc79f8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 13:46:56 +0200 Subject: [PATCH 035/181] fix imports --- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 26874f53d225..a02d9467b94e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -175,7 +175,7 @@ VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, - WuerstchenPipeline, + WuerstchenGeneratorPipeline, WuerstchenPriorPipeline, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1dee303b5437..e13d0254d962 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -99,7 +99,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline + from .wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 03c23f48d2c3..c39927a1772a 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -3,4 +3,4 @@ if is_transformers_available() and is_torch_available(): from .modules import Prior - from .pipeline_wuerstchen import WuerstchenPipeline, WuerstchenPriorPipeline + from .pipeline_wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline From ea2c64ec31e8387c52a9ce2ef6e0411525bbd347 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 14:03:33 +0200 Subject: [PATCH 036/181] convert generator pipeline --- scripts/convert_wuerstchen.py | 3 ++- src/diffusers/pipelines/wuerstchen/modules.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 4f0107a51b7b..454b166977c8 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -67,11 +67,12 @@ prior_pipeline.save_pretrained("kashif/WuerstchenPriorPipeline") generator_pipeline = WuerstchenGeneratorPipeline( - vae=vqmodel, + vqgan=vqmodel, generator=generator, efficient_net=efficient_net, scheduler=scheduler, ) +generator_pipeline.save_pretrained("kashif/WuerstchenGeneratorPipeline") # WuerstchenPipeline( diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index c590d83c3199..2bb4848a151d 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -110,7 +110,8 @@ def forward(self, x, kv): return x -class EfficientNetEncoder(nn.Module): +class EfficientNetEncoder(ModelMixin, ConfigMixin): + @register_to_config def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): super().__init__() if effnet == "efficientnet_v2_s": From 96cb4dec5bd6b76484804b73e691fed17ef96d9e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 30 Jun 2023 14:14:10 +0200 Subject: [PATCH 037/181] fix convert --- scripts/convert_wuerstchen.py | 3 +-- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/modules.py | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 454b166977c8..930cdb3b4ee9 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -3,7 +3,6 @@ import torch from transformers import AutoTokenizer, CLIPTextModel from vqgan import VQModel -from modules import DiffNeXt, EfficientNetEncoder from diffusers import ( DDPMScheduler, @@ -11,7 +10,7 @@ WuerstchenPriorPipeline, WuerstchenGeneratorPipeline, ) -from diffusers.pipelines.wuerstchen import Prior +from diffusers.pipelines.wuerstchen import Prior, DiffNeXt, EfficientNetEncoder model_path = "models/" diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index c39927a1772a..0b312ce9a867 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,5 +2,5 @@ if is_transformers_available() and is_torch_available(): - from .modules import Prior + from .modules import Prior, DiffNeXt, EfficientNetEncoder from .pipeline_wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 2bb4848a151d..c4e507aac116 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -201,7 +201,8 @@ def update_weights_ema(self, src_model, beta=0.999): self_params.data = self_params.data * beta + src_params.data * (1 - beta) -class DiffNeXt(nn.Module): +class DiffNeXt(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, c_in=4, From 2c6d0dd0407ffed7219024cabb5472bd8610b19c Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Sat, 1 Jul 2023 17:12:38 +0200 Subject: [PATCH 038/181] Work on Generator Pipeline. WIP --- scripts/wuerstchen_pipeline_test.py | 8 +- src/diffusers/pipelines/wuerstchen/modules.py | 2 +- .../wuerstchen/pipeline_wuerstchen.py | 94 +++++++++++++------ 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 2fb4ec1da319..1a47b9c1d89a 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -1,10 +1,10 @@ import torch -from diffusers import WuerstchenPriorPipeline +from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) +generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenGeneratorPipeline", torch_dtype=torch.float16) prior_pipeline = prior_pipeline.to("cuda") - -generator_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) generator_pipeline = generator_pipeline.to("cuda") -generator_output = generator_pipeline("An image of a squirrel in Picasso style") \ No newline at end of file +prior_output = prior_pipeline("An image of a squirrel in Picasso style") +generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds) diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index c4e507aac116..211fa92ccb1a 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -328,7 +328,7 @@ def gen_r_embedding(self, r, max_positions=10000): emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.c_r % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode="constant") - return emb + return emb.to(dtype=r.dtype) def gen_c_embeddings(self, clip): clip = self.clip_mapper(clip) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index b32e0297870e..9c87cca64c1b 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -265,7 +265,6 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: int = 30, inference_steps: dict = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -303,7 +302,8 @@ def __call__( latent_width = 128 * (width // 128) // (1024 // 24) effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) - self.scheduler.set_timesteps(num_inference_steps, device=device) + total_num_inference_steps = sum(inference_steps.values()) + self.scheduler.set_timesteps(total_num_inference_steps, device=device) latents = self.prepare_latents( effnet_features_shape, @@ -402,13 +402,58 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + def check_inputs(self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device): + if not isinstance(text_encoder_hidden_states, torch.Tensor): + raise TypeError(f"'text_encoder_hidden_states' must be of type 'torch.Tensor', but got {type(predicted_image_embeddings)}.") + if isinstance(predicted_image_embeddings, np.ndarray): + predicted_image_embeddings = torch.Tensor(predicted_image_embeddings, device=device).to(dtype=text_encoder_hidden_states.dtype) + if not isinstance(predicted_image_embeddings, torch.Tensor): + raise TypeError(f"'predicted_image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(predicted_image_embeddings)}.") + + if do_classifier_free_guidance: + assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0) // 2, f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + else: + assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0), f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + + return predicted_image_embeddings, text_encoder_hidden_states + + @torch.no_grad() + def inference_loop( + self, latents, steps, predicted_effnet_latents, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + ): + for t in self.progress_bar(steps): + print(torch.cat([latents] * 2).shape, latents.dtype) + print(t.expand(latents.size(0) * 2).shape, t.dtype) + print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + predicted_image_embedding = self.generator( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, + effnet=predicted_effnet_latents, clip=text_encoder_hidden_states, + ) + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + # print(t) + timestep = (t * 999).cpu().int() + # print(timestep) + latents = self.scheduler.step( + predicted_image_embedding, + timestep=timestep - 1, + sample=latents, + generator=generator, + ).prev_sample + + return latents + @torch.no_grad() def __call__( self, predicted_image_embeddings: torch.Tensor, text_encoder_hidden_states: torch.Tensor, - num_inference_steps: int = 30, - timesteps: List[int] = None, + inference_steps: dict = None, guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -420,15 +465,18 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 - check_inputs(text_encoder_hidden_states, predicted_image_embeddings, do_classifier_free_guidance) + if inference_steps is None: + inference_steps = default_inference_steps + + predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs(predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device) dtype = text_encoder_hidden_states.dtype - latent_height = predicted_image_embeddings.size(2) * (1024 // 24) - latent_width = predicted_image_embeddings.size(2) * (1024 // 24) + latent_height = int(predicted_image_embeddings.size(2) * (256 / 24)) + latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) effnet_features_shape = (num_images_per_prompt, 4, latent_height, latent_width) - self.scheduler.set_timesteps(num_inference_steps, device=device) - generator_timesteps_tensor = self.scheduler.timesteps if timesteps is None else timesteps + total_num_inference_steps = sum(inference_steps.values()) + self.scheduler.set_timesteps(total_num_inference_steps, device=device) latents = self.prepare_latents( effnet_features_shape, @@ -439,29 +487,13 @@ def __call__( self.scheduler, ) # print(generator_timesteps_tensor) - for i, t in enumerate(self.progress_bar(generator_timesteps_tensor)): - ratio = (t / generator_timesteps_tensor.max()).to(dtype=dtype) - # print(torch.cat([latents] * 2).shape, latents.dtype) - # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) - # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - predicted_image_embedding = self.generator( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=ratio.expand(num_images_per_prompt * 2) if do_classifier_free_guidance else ratio, - c=text_encoder_hidden_states, + t_start = 1.0 + for t_end, steps in inference_steps.items(): + steps = torch.linspace(t_start, t_end, steps, dtype=dtype, device=device) + latents = self.inference_loop( + latents, steps, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ) - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - latents = self.scheduler.step( - predicted_image_embedding, - timestep=t, - sample=latents, - generator=generator, - ).prev_sample + t_start = t_end if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") From 32817982ac251a7eeea6c7fdfa9fd486a0feffac Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 3 Jul 2023 03:46:54 +0200 Subject: [PATCH 039/181] Pipeline works with our diffuzz code --- scripts/wuerstchen_pipeline_test.py | 34 +++- src/diffusers/models/vq_model.py | 2 +- src/diffusers/pipelines/wuerstchen/diffuzz.py | 123 +++++++++++++++ src/diffusers/pipelines/wuerstchen/modules.py | 6 +- .../wuerstchen/pipeline_wuerstchen.py | 147 ++++++++++++------ src/diffusers/pipelines/wuerstchen/vqgan.py | 144 +++++++++++++++++ 6 files changed, 403 insertions(+), 53 deletions(-) create mode 100644 src/diffusers/pipelines/wuerstchen/diffuzz.py create mode 100644 src/diffusers/pipelines/wuerstchen/vqgan.py diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 1a47b9c1d89a..c7a1c6a1679a 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -1,10 +1,40 @@ +import os +import numpy as np import torch +from PIL import Image from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + +def numpy_to_pil(images: np.ndarray) -> list[Image]: + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenGeneratorPipeline", torch_dtype=torch.float16) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") -prior_output = prior_pipeline("An image of a squirrel in Picasso style") -generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds) +negative_prompt = "low resolution, low detail, bad quality, blurry" +# negative_prompt = "" +# caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" +# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" +caption = input("Prompt please: ") +while caption != "q": + prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) + generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds, output_type="np").images + images = numpy_to_pil(generator_output) + + os.makedirs("samples", exist_ok=True) + for i, image in enumerate(images): + image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) + + caption = input("Prompt please: ") + diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index f56acfdf8e94..fcce5ba880f4 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -119,7 +119,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOut return VQEncoderOutput(latents=h) def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: if not force_not_quantize: quant, _, _ = self.quantize(h) diff --git a/src/diffusers/pipelines/wuerstchen/diffuzz.py b/src/diffusers/pipelines/wuerstchen/diffuzz.py new file mode 100644 index 000000000000..25c3db27dfd2 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/diffuzz.py @@ -0,0 +1,123 @@ +import torch + + +# Samplers -------------------------------------------------------------------- +class SimpleSampler(): + def __init__(self, diffuzz): + self.current_step = -1 + self.diffuzz = diffuzz + + def __call__(self, *args, **kwargs): + self.current_step += 1 + return self.step(*args, **kwargs) + + def init_x(self, shape): + return torch.randn(*shape, device=self.diffuzz.device) + + def step(self, x, t, t_prev, noise): + raise NotImplementedError("You should override the 'apply' function.") + + +class DDPMSampler(SimpleSampler): + def step(self, x, t, t_prev, noise): + alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu) + return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + + +class DDIMSampler(SimpleSampler): + def step(self, x, t, t_prev, noise): + alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + + x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() + dp_xt = (1 - alpha_cumprod_prev).sqrt() + return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise + + +sampler_dict = { + 'ddpm': DDPMSampler, + 'ddim': DDIMSampler, +} + + +# Custom simplified foward/backward diffusion (cosine schedule) +class Diffuzz(): + def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1): + self.device = device + self.s = torch.tensor([s]).to(device) + self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + self.scaler = scaler + self.cached_steps = None + if cache_steps is not None: + self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) + + def _alpha_cumprod(self, t): + if self.cached_steps is None: + if self.scaler > 1: + t = 1 - (1 - t) ** self.scaler + elif self.scaler < 1: + t = t ** self.scaler + alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod + return alpha_cumprod.clamp(0.0001, 0.9999) + else: + return self.cached_steps[t.mul(len(self.cached_steps) - 1).long()] + + def diffuse(self, x, t, noise=None): # t -> [0, 1] + if noise is None: + noise = torch.randn_like(x) + alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + return alpha_cumprod.sqrt() * x + (1 - alpha_cumprod).sqrt() * noise, noise + + def undiffuse(self, x, t, t_prev, noise, sampler=None): + if sampler is None: + sampler = DDPMSampler(self) + return sampler(x, t, t_prev, noise) + + def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, + unconditional_inputs=None, sampler='ddpm', half=False): + r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[ + 0] if x_init is None else x_init.size(0)).to(self.device) + if isinstance(sampler, str): + if sampler in sampler_dict: + sampler = sampler_dict[sampler](self) + else: + raise ValueError( + f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") + elif issubclass(sampler, SimpleSampler): + sampler = sampler(self) + else: + raise ValueError("Sampler should be either a string or a SimpleSampler object.") + preds = [] + x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() + if half: + r_range = r_range.half() + x = x.half() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = {k: torch.cat([v, v_u]) for (k, v), (k_u, v_u) in + zip(model_inputs.items(), unconditional_inputs.items())} + for i in range(0, timesteps): + if mask is not None and x_init is not None: + x_renoised, _ = self.diffuse(x_init, r_range[i]) + x = x * mask + x_renoised * (1 - mask) + + if cfg is not None: + pred_noise, pred_noise_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), + **model_inputs).chunk(2) + pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) + else: + pred_noise = model(x, r_range[i], **model_inputs) + + x = self.undiffuse(x, r_range[i], r_range[i + 1], pred_noise, sampler=sampler) + preds.append(x) + return preds + + def p2_weight(self, t, k=1.0, gamma=1.0): + alpha_cumprod = self._alpha_cumprod(t) + return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma \ No newline at end of file diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 211fa92ccb1a..df34ad0e7d4b 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -342,10 +342,11 @@ def _down_encode(self, x, r_embed, effnet, clip): for block in down_block: if isinstance(block, ResBlockStageB): if effnet_c is None and self.effnet_mappers[i] is not None: + dtype = effnet.dtype effnet_c = self.effnet_mappers[i]( nn.functional.interpolate( effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True - ) + ).to(dtype) ) skip = effnet_c if self.effnet_mappers[i] is not None else None x = block(x, skip) @@ -365,10 +366,11 @@ def _up_decode(self, level_outputs, r_embed, effnet, clip): for j, block in enumerate(up_block): if isinstance(block, ResBlockStageB): if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None: + dtype = effnet.dtype effnet_c = self.effnet_mappers[len(self.down_blocks) + i]( nn.functional.interpolate( effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True - ) + ).to(dtype) ) skip = level_outputs[i] if j == 0 and i > 0 else None if effnet_c is not None: diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9c87cca64c1b..558327d4310a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -15,8 +15,10 @@ from dataclasses import dataclass from typing import List, Optional, Union +from PIL import Image import numpy as np import torch +from PIL.Image import Image from transformers import CLIPTextModel, CLIPTokenizer from ...models import PaellaVQModel @@ -25,6 +27,9 @@ from ..pipeline_utils import DiffusionPipeline from .modules import DiffNeXt, Prior, EfficientNetEncoder +from .diffuzz import Diffuzz +from .vqgan import VQModel + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -42,7 +47,10 @@ ``` """ -default_inference_steps = {2 / 3: 20, 0.0: 10} + +default_inference_steps_c = {2 / 3: 20, 0.0: 10} +# default_inference_steps_c = {0.0: 60} +default_inference_steps_b = {0.0: 12} class WuerstchenPipeline(DiffusionPipeline): @@ -78,7 +86,7 @@ class WuerstchenGeneratorPipelineOutput(BaseOutput): Clip text embeddings for unconditional tokens """ - image_embeds: Union[torch.FloatTensor, np.ndarray] + images: Union[torch.FloatTensor, np.ndarray] class WuerstchenPriorPipeline(DiffusionPipeline): @@ -115,6 +123,7 @@ def __init__( prior=prior, scheduler=scheduler, ) + self.diffuzz = Diffuzz(device="cuda") self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): @@ -206,7 +215,7 @@ def _encode_prompt( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) return text_encoder_hidden_states @@ -232,9 +241,11 @@ def _execution_device(self): def inference_loop( self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ): - for t in self.progress_bar(steps): + print(steps) + print(steps[:-1]) + for i, t in enumerate(self.progress_bar(steps[:-1])): # print(torch.cat([latents] * 2).shape, latents.dtype) - # print(ratio.expand(num_images_per_prompt * 2).shape, ratio.dtype) + # print(t.expand(latents.size(0) * 2).shape, t.dtype) # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) predicted_image_embedding = self.prior( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, @@ -242,20 +253,26 @@ def inference_loop( c=text_encoder_hidden_states, ) + # print(t.expand(latents.size(0) * 2)) + # print(i, predicted_image_embedding[0, 0, :4, :4]) + # print(text_encoder_hidden_states[0, 4, :4]) + if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) # print(t) - timestep = (t * 999).cpu().int() - # print(timestep) - latents = self.scheduler.step( - predicted_image_embedding, - timestep=timestep - 1, - sample=latents, - generator=generator, - ).prev_sample + + latents = self.diffuzz.undiffuse(latents, t[None], steps[i+1][None], predicted_image_embedding).to(dtype=t.dtype) + # timestep = (t * 999).cpu().int() + # # print(timestep) + # latents = self.scheduler.step( + # predicted_image_embedding, + # timestep=timestep - 1, + # sample=latents, + # generator=generator, + # ).prev_sample return latents @@ -266,7 +283,7 @@ def __call__( height: int = 1024, width: int = 1024, inference_steps: dict = None, - guidance_scale: float = 7.0, + guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -279,7 +296,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if inference_steps is None: - inference_steps = default_inference_steps + inference_steps = default_inference_steps_c if negative_prompt is None: negative_prompt = "" @@ -305,18 +322,37 @@ def __call__( total_num_inference_steps = sum(inference_steps.values()) self.scheduler.set_timesteps(total_num_inference_steps, device=device) - latents = self.prepare_latents( - effnet_features_shape, - dtype, - device, - generator, - latents, - self.scheduler, - ) + def seed_everything(seed: int): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + seed_everything(42) + + # latents = self.prepare_latents( + # effnet_features_shape, + # dtype, + # device, + # generator, + # latents, + # self.scheduler, + # ) + + latents = torch.randn(effnet_features_shape, device=device) + print(latents[0, 0, :4, :4]) + latents = latents.to(dtype=dtype) t_start = 1.0 for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps, dtype=dtype, device=device) + steps = torch.linspace(t_start, t_end, steps+1, dtype=dtype, device=device) latents = self.inference_loop( latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ) @@ -371,6 +407,13 @@ def __init__( vqgan=vqgan, efficient_net=efficient_net, ) + self.diffuzz = Diffuzz(device="cuda") + + self.vqmodel = VQModel() + state_dict = torch.load(r"C:\Users\d6582\Documents\ml\diffusers\scripts\models\vqgan_f4_v1_500k.pt")["state_dict"] + self.vqmodel.load_state_dict(state_dict) + self.vqmodel.to("cuda").to(torch.float16) + self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): @@ -413,6 +456,8 @@ def check_inputs(self, predicted_image_embeddings, text_encoder_hidden_states, d if do_classifier_free_guidance: assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0) // 2, f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." else: + if predicted_image_embeddings.size(0) * 2 == text_encoder_hidden_states.size(0): + text_encoder_hidden_states = text_encoder_hidden_states.chunk(2)[0] assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0), f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." return predicted_image_embeddings, text_encoder_hidden_states @@ -421,30 +466,34 @@ def check_inputs(self, predicted_image_embeddings, text_encoder_hidden_states, d def inference_loop( self, latents, steps, predicted_effnet_latents, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ): - for t in self.progress_bar(steps): - print(torch.cat([latents] * 2).shape, latents.dtype) - print(t.expand(latents.size(0) * 2).shape, t.dtype) - print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + for i, t in enumerate(self.progress_bar(steps[:-1])): + # print(torch.cat([latents] * 2).shape, latents.dtype, latents.device) + # print(t.expand(latents.size(0) * 2).shape, t.dtype, t.device) + # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype, text_encoder_hidden_states.device) + # print(predicted_effnet_latents.shape, predicted_effnet_latents.dtype, predicted_effnet_latents.device) predicted_image_embedding = self.generator( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, - effnet=predicted_effnet_latents, clip=text_encoder_hidden_states, + r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t[None], + effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) if do_classifier_free_guidance else predicted_effnet_latents, + clip=text_encoder_hidden_states, ) if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) # print(t) - timestep = (t * 999).cpu().int() - # print(timestep) - latents = self.scheduler.step( - predicted_image_embedding, - timestep=timestep - 1, - sample=latents, - generator=generator, - ).prev_sample + latents = self.diffuzz.undiffuse(latents, t[None], steps[i+1][None], predicted_image_embedding).to(dtype=t.dtype) + + # timestep = (t * 999).cpu().int() + # # print(timestep) + # latents = self.scheduler.step( + # predicted_image_embedding, + # timestep=timestep - 1, + # sample=latents, + # generator=generator, + # ).prev_sample return latents @@ -454,8 +503,7 @@ def __call__( predicted_image_embeddings: torch.Tensor, text_encoder_hidden_states: torch.Tensor, inference_steps: dict = None, - guidance_scale: float = 7.0, - num_images_per_prompt: Optional[int] = 1, + guidance_scale: float = 0., generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pt", # pt only @@ -466,14 +514,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if inference_steps is None: - inference_steps = default_inference_steps + inference_steps = default_inference_steps_b predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs(predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device) dtype = text_encoder_hidden_states.dtype latent_height = int(predicted_image_embeddings.size(2) * (256 / 24)) latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) - effnet_features_shape = (num_images_per_prompt, 4, latent_height, latent_width) + effnet_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) total_num_inference_steps = sum(inference_steps.values()) self.scheduler.set_timesteps(total_num_inference_steps, device=device) @@ -489,19 +537,22 @@ def __call__( # print(generator_timesteps_tensor) t_start = 1.0 for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps, dtype=dtype, device=device) + steps = torch.linspace(t_start, t_end, steps+1, dtype=dtype, device=device) latents = self.inference_loop( latents, steps, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ) t_start = t_end + images = self.vqmodel.decode(latents).clamp(0, 1) + # images = self.vqgan.decode(latents).sample + if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": - latents = latents.cpu().numpy() + images = images.permute(0, 2, 3, 1).cpu().numpy() if not return_dict: - return (latents, text_encoder_hidden_states) + return images - return WuerstchenGeneratorPipelineOutput(latents) + return WuerstchenGeneratorPipelineOutput(images) diff --git a/src/diffusers/pipelines/wuerstchen/vqgan.py b/src/diffusers/pipelines/wuerstchen/vqgan.py new file mode 100644 index 000000000000..935023e1ff6e --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/vqgan.py @@ -0,0 +1,144 @@ +import torch +from torch import nn +import numpy as np +import math +from tqdm import tqdm +import time +from torchtools.nn import VectorQuantize + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + +class VQModel(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3*4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i]*4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) + up_blocks.append(block) + if i < levels-1: + up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3*4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x): + x = self.in_block(x) + x = self.down_blocks(x) + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def decode_indices(self, x): + x = self.vquantizer.idx2vq(x, dim=1) + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x \ No newline at end of file From 59667d97b4b4f3c04d509ecdf88360a7d8a1f6a9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Jul 2023 06:12:48 +0200 Subject: [PATCH 040/181] apply scale factor --- src/diffusers/models/vq_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index fcce5ba880f4..1746dcf3bf88 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -111,7 +111,7 @@ def __init__( def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: h = self.in_block(x) - h = self.down_blocks(h) + h = self.down_blocks(h) / self.config.scale_factor if not return_dict: return (h,) @@ -122,9 +122,10 @@ def decode( self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: if not force_not_quantize: - quant, _, _ = self.quantize(h) + quant, _, _ = self.vquantizer(h * self.config.scale_factor) else: - quant = h + quant = h * self.config.scale_factor + x = self.up_blocks(quant) dec = self.out_block(x) if not return_dict: From 156901f8157f8775f7c1dcdfc439d11468e4742d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Jul 2023 06:42:22 +0200 Subject: [PATCH 041/181] removed vqgan.py --- .../wuerstchen/pipeline_wuerstchen.py | 74 ++++++--- src/diffusers/pipelines/wuerstchen/vqgan.py | 144 ------------------ 2 files changed, 50 insertions(+), 168 deletions(-) delete mode 100644 src/diffusers/pipelines/wuerstchen/vqgan.py diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 558327d4310a..64b29fd9c0c3 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -28,7 +28,6 @@ from .modules import DiffNeXt, Prior, EfficientNetEncoder from .diffuzz import Diffuzz -from .vqgan import VQModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -264,7 +263,9 @@ def inference_loop( ) # print(t) - latents = self.diffuzz.undiffuse(latents, t[None], steps[i+1][None], predicted_image_embedding).to(dtype=t.dtype) + latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( + dtype=t.dtype + ) # timestep = (t * 999).cpu().int() # # print(timestep) # latents = self.scheduler.step( @@ -328,7 +329,7 @@ def seed_everything(seed: int): import torch random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -352,7 +353,7 @@ def seed_everything(seed: int): t_start = 1.0 for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps+1, dtype=dtype, device=device) + steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) latents = self.inference_loop( latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator ) @@ -409,11 +410,6 @@ def __init__( ) self.diffuzz = Diffuzz(device="cuda") - self.vqmodel = VQModel() - state_dict = torch.load(r"C:\Users\d6582\Documents\ml\diffusers\scripts\models\vqgan_f4_v1_500k.pt")["state_dict"] - self.vqmodel.load_state_dict(state_dict) - self.vqmodel.to("cuda").to(torch.float16) - self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): @@ -445,26 +441,45 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def check_inputs(self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device): + def check_inputs( + self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device + ): if not isinstance(text_encoder_hidden_states, torch.Tensor): - raise TypeError(f"'text_encoder_hidden_states' must be of type 'torch.Tensor', but got {type(predicted_image_embeddings)}.") + raise TypeError( + f"'text_encoder_hidden_states' must be of type 'torch.Tensor', but got {type(predicted_image_embeddings)}." + ) if isinstance(predicted_image_embeddings, np.ndarray): - predicted_image_embeddings = torch.Tensor(predicted_image_embeddings, device=device).to(dtype=text_encoder_hidden_states.dtype) + predicted_image_embeddings = torch.Tensor(predicted_image_embeddings, device=device).to( + dtype=text_encoder_hidden_states.dtype + ) if not isinstance(predicted_image_embeddings, torch.Tensor): - raise TypeError(f"'predicted_image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(predicted_image_embeddings)}.") + raise TypeError( + f"'predicted_image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(predicted_image_embeddings)}." + ) if do_classifier_free_guidance: - assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0) // 2, f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + assert ( + predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0) // 2 + ), f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." else: if predicted_image_embeddings.size(0) * 2 == text_encoder_hidden_states.size(0): text_encoder_hidden_states = text_encoder_hidden_states.chunk(2)[0] - assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0), f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size( + 0 + ), f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." return predicted_image_embeddings, text_encoder_hidden_states @torch.no_grad() def inference_loop( - self, latents, steps, predicted_effnet_latents, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + self, + latents, + steps, + predicted_effnet_latents, + text_encoder_hidden_states, + do_classifier_free_guidance, + guidance_scale, + generator, ): for i, t in enumerate(self.progress_bar(steps[:-1])): # print(torch.cat([latents] * 2).shape, latents.dtype, latents.device) @@ -474,7 +489,9 @@ def inference_loop( predicted_image_embedding = self.generator( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t[None], - effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) if do_classifier_free_guidance else predicted_effnet_latents, + effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) + if do_classifier_free_guidance + else predicted_effnet_latents, clip=text_encoder_hidden_states, ) @@ -484,7 +501,9 @@ def inference_loop( predicted_image_embedding_text - predicted_image_embedding_uncond ) # print(t) - latents = self.diffuzz.undiffuse(latents, t[None], steps[i+1][None], predicted_image_embedding).to(dtype=t.dtype) + latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( + dtype=t.dtype + ) # timestep = (t * 999).cpu().int() # # print(timestep) @@ -503,7 +522,7 @@ def __call__( predicted_image_embeddings: torch.Tensor, text_encoder_hidden_states: torch.Tensor, inference_steps: dict = None, - guidance_scale: float = 0., + guidance_scale: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pt", # pt only @@ -516,7 +535,9 @@ def __call__( if inference_steps is None: inference_steps = default_inference_steps_b - predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs(predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device) + predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs( + predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device + ) dtype = text_encoder_hidden_states.dtype latent_height = int(predicted_image_embeddings.size(2) * (256 / 24)) @@ -537,14 +558,19 @@ def __call__( # print(generator_timesteps_tensor) t_start = 1.0 for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps+1, dtype=dtype, device=device) + steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) latents = self.inference_loop( - latents, steps, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + latents, + steps, + predicted_image_embeddings, + text_encoder_hidden_states, + do_classifier_free_guidance, + guidance_scale, + generator, ) t_start = t_end - images = self.vqmodel.decode(latents).clamp(0, 1) - # images = self.vqgan.decode(latents).sample + images = self.vqgan.decode(latents).sample if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") diff --git a/src/diffusers/pipelines/wuerstchen/vqgan.py b/src/diffusers/pipelines/wuerstchen/vqgan.py deleted file mode 100644 index 935023e1ff6e..000000000000 --- a/src/diffusers/pipelines/wuerstchen/vqgan.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -from torch import nn -import numpy as np -import math -from tqdm import tqdm -import time -from torchtools.nn import VectorQuantize - -class ResBlock(nn.Module): - def __init__(self, c, c_hidden): - super().__init__() - # depthwise/attention - self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.depthwise = nn.Sequential( - nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) - ) - - # channelwise - self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), - nn.GELU(), - nn.Linear(c_hidden, c), - ) - - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - - def _norm(self, x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - def forward(self, x): - mods = self.gammas - - x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] - - x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - - return x - -class VQModel(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 - super().__init__() - self.c_latent = c_latent - self.scale_factor = scale_factor - c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] - - # Encoder blocks - self.in_block = nn.Sequential( - nn.PixelUnshuffle(2), - nn.Conv2d(3*4, c_levels[0], kernel_size=1) - ) - down_blocks = [] - for i in range(levels): - if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) - block = ResBlock(c_levels[i], c_levels[i]*4) - down_blocks.append(block) - down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - )) - self.down_blocks = nn.Sequential(*down_blocks) - self.down_blocks[0] - - self.codebook_size = codebook_size - self.vquantizer = VectorQuantize(c_latent, k=codebook_size) - - # Decoder blocks - up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) - )] - for i in range(levels): - for j in range(bottleneck_blocks if i == 0 else 1): - block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) - up_blocks.append(block) - if i < levels-1: - up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3*4, kernel_size=1), - nn.PixelShuffle(2), - ) - - def encode(self, x): - x = self.in_block(x) - x = self.down_blocks(x) - qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 - - def decode(self, x): - x = x * self.scale_factor - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def decode_indices(self, x): - x = self.vquantizer.idx2vq(x, dim=1) - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def forward(self, x, quantize=False): - qe, x, _, vq_loss = self.encode(x, quantize) - x = self.decode(qe) - return x, vq_loss - -class Discriminator(nn.Module): - def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): - super().__init__() - d = max(depth - 3, 3) - layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), - nn.LeakyReLU(0.2), - ] - for i in range(depth - 1): - c_in = c_hidden // (2 ** max((d - i), 0)) - c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) - layers.append(nn.InstanceNorm2d(c_out)) - layers.append(nn.LeakyReLU(0.2)) - self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) - self.logits = nn.Sigmoid() - - def forward(self, x, cond=None): - x = self.encoder(x) - if cond is not None: - cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) - x = torch.cat([x, cond], dim=1) - x = self.shuffle(x) - x = self.logits(x) - return x \ No newline at end of file From ddc9daa50c6639f480a86a9cea38567d37a0acc1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Jul 2023 19:01:55 +0200 Subject: [PATCH 042/181] use cosine schedule --- scripts/convert_wuerstchen.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 930cdb3b4ee9..fb56214a0adb 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -47,13 +47,11 @@ prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["ema_state_dict"]) +# Trained betas for scheduler via cosine +trained_betas = [] # scheduler -scheduler = DDPMScheduler( - beta_schedule="linear", - beta_start=0.0001, - beta_end=0.02, -) +scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2") # Prior pipeline prior_pipeline = WuerstchenPriorPipeline( From 30e4888104bf69fdf9612744db8791ce8c0a607a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 9 Jul 2023 18:17:00 +0200 Subject: [PATCH 043/181] redo the denoising loop --- scripts/convert_wuerstchen.py | 44 --- .../wuerstchen/pipeline_wuerstchen.py | 308 ++++++++++-------- 2 files changed, 176 insertions(+), 176 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index fb56214a0adb..ba74bf376b92 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -70,47 +70,3 @@ scheduler=scheduler, ) generator_pipeline.save_pretrained("kashif/WuerstchenGeneratorPipeline") - - -# WuerstchenPipeline( -# vae=VQGan() -# text_encoder=ClipTextEncoder(), -# prior=prior, -# (image_encoder)=efficient_net, -# ) -# stage C = prior -# stage B = unet -# stage A = vae -# WuerstchenPipeline( -# vae=VQGan() -# text_encoder=ClipTextEncoder(), -# unet = UNet2DConditionModel(), -# prior=prior, -# (image_encoder)=efficient_net, -# ) -# Patrick von Platen4:17 PM -# WuerstchenPipeline( -# vae=VQGan() -# text_encoder=ClipTextEncoder(), -# unet = UNet2DConditionModel(), -# prior=prior, -# tokenizer=CLIPTokenizer, -# (image_encoder)=efficient_net, -# ) -# WuerstchenPipeline( -# vae=VQGan() -# text_encoder=ClipTextEncoder(), -# unet = UNet2DConditionModel(), -# prior=PriorTransformer(), -# tokenizer=CLIPTokenizer, -# (image_encoder)=efficient_net, -# ) -# Patrick von Platen4:20 PM -# WuerstchenPipeline( -# vae=VQGan() -# text_encoder=ClipTextEncoder(), -# unet = NewUNet(), # Paella Style -# prior=NewPrior(), # find good name -# tokenizer=CLIPTokenizer, -# (image_encoder)=efficient_net, -# ) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 64b29fd9c0c3..51cab1e8ff66 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -27,7 +27,7 @@ from ..pipeline_utils import DiffusionPipeline from .modules import DiffNeXt, Prior, EfficientNetEncoder -from .diffuzz import Diffuzz +# from .diffuzz import Diffuzz logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -36,25 +36,21 @@ Examples: ```py >>> import torch - >>> from diffusers import WuerstchenPipeline + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline - >>> pipe = WuerstchenPipeline.from_pretrained("kashif/wuerstchen", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained("kashif/wuerstchen-prior", torch_dtype=torch.float16).to("cuda") + >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain("kashif/wuerstchen-gen", torch_dtype=torch.float16).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" - >>> image = pipe(prompt).images[0] + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeds, prior_output.text_embeds) ``` """ default_inference_steps_c = {2 / 3: 20, 0.0: 10} # default_inference_steps_c = {0.0: 60} -default_inference_steps_b = {0.0: 12} - - -class WuerstchenPipeline(DiffusionPipeline): - unet: DiffNeXt - vqmodel: PaellaVQModel +default_inference_steps_b = {0.0: 30} @dataclass @@ -122,7 +118,7 @@ def __init__( prior=prior, scheduler=scheduler, ) - self.diffuzz = Diffuzz(device="cuda") + # self.diffuzz = Diffuzz(device="cuda") self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): @@ -236,46 +232,46 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - @torch.no_grad() - def inference_loop( - self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - ): - print(steps) - print(steps[:-1]) - for i, t in enumerate(self.progress_bar(steps[:-1])): - # print(torch.cat([latents] * 2).shape, latents.dtype) - # print(t.expand(latents.size(0) * 2).shape, t.dtype) - # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - predicted_image_embedding = self.prior( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, - c=text_encoder_hidden_states, - ) - - # print(t.expand(latents.size(0) * 2)) - # print(i, predicted_image_embedding[0, 0, :4, :4]) - # print(text_encoder_hidden_states[0, 4, :4]) - - if do_classifier_free_guidance: - predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - # print(t) - - latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - dtype=t.dtype - ) - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - return latents + # @torch.no_grad() + # def inference_loop( + # self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + # ): + # print(steps) + # print(steps[:-1]) + # for i, t in enumerate(self.progress_bar(steps[:-1])): + # # print(torch.cat([latents] * 2).shape, latents.dtype) + # # print(t.expand(latents.size(0) * 2).shape, t.dtype) + # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + # predicted_image_embedding = self.prior( + # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, + # c=text_encoder_hidden_states, + # ) + + # # print(t.expand(latents.size(0) * 2)) + # # print(i, predicted_image_embedding[0, 0, :4, :4]) + # # print(text_encoder_hidden_states[0, 4, :4]) + + # if do_classifier_free_guidance: + # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + # predicted_image_embedding_text - predicted_image_embedding_uncond + # ) + # # print(t) + + # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( + # # dtype=t.dtype + # # ) + # timestep = (t * 999).cpu().int() + # # print(timestep) + # latents = self.scheduler.step( + # predicted_image_embedding, + # timestep=timestep - 1, + # sample=latents, + # generator=generator, + # ).prev_sample + + # return latents @torch.no_grad() def __call__( @@ -284,7 +280,7 @@ def __call__( height: int = 1024, width: int = 1024, inference_steps: dict = None, - guidance_scale: float = 8.0, + guidance_scale: float = 3.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -322,6 +318,7 @@ def __call__( total_num_inference_steps = sum(inference_steps.values()) self.scheduler.set_timesteps(total_num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps def seed_everything(seed: int): import random, os @@ -338,26 +335,47 @@ def seed_everything(seed: int): seed_everything(42) - # latents = self.prepare_latents( - # effnet_features_shape, - # dtype, - # device, - # generator, - # latents, - # self.scheduler, - # ) - - latents = torch.randn(effnet_features_shape, device=device) - print(latents[0, 0, :4, :4]) - latents = latents.to(dtype=dtype) - - t_start = 1.0 - for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) - latents = self.inference_loop( - latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + latents = self.prepare_latents( + effnet_features_shape, + dtype, + device, + generator, + latents, + self.scheduler, + ) + + # latents = torch.randn(effnet_features_shape, device=device) + # print(latents[0, 0, :4, :4]) + # latents = latents.to(dtype=dtype) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) # between 0 and 1 + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, ) - t_start = t_end + + if do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + ).prev_sample + + # t_start = 1.0 + # for t_end, steps in inference_steps.items(): + # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) + # latents = self.inference_loop( + # latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + # ) + # t_start = t_end # normalize the latents latents = latents * 42.0 - 1.0 @@ -408,7 +426,7 @@ def __init__( vqgan=vqgan, efficient_net=efficient_net, ) - self.diffuzz = Diffuzz(device="cuda") + # self.diffuzz = Diffuzz(device="cuda") self.register_to_config() @@ -470,51 +488,51 @@ def check_inputs( return predicted_image_embeddings, text_encoder_hidden_states - @torch.no_grad() - def inference_loop( - self, - latents, - steps, - predicted_effnet_latents, - text_encoder_hidden_states, - do_classifier_free_guidance, - guidance_scale, - generator, - ): - for i, t in enumerate(self.progress_bar(steps[:-1])): - # print(torch.cat([latents] * 2).shape, latents.dtype, latents.device) - # print(t.expand(latents.size(0) * 2).shape, t.dtype, t.device) - # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype, text_encoder_hidden_states.device) - # print(predicted_effnet_latents.shape, predicted_effnet_latents.dtype, predicted_effnet_latents.device) - predicted_image_embedding = self.generator( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t[None], - effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) - if do_classifier_free_guidance - else predicted_effnet_latents, - clip=text_encoder_hidden_states, - ) - - if do_classifier_free_guidance: - predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - # print(t) - latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - dtype=t.dtype - ) - - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - return latents + # @torch.no_grad() + # def inference_loop( + # self, + # latents, + # steps, + # predicted_effnet_latents, + # text_encoder_hidden_states, + # do_classifier_free_guidance, + # guidance_scale, + # generator, + # ): + # for i, t in enumerate(self.progress_bar(steps[:-1])): + # # print(torch.cat([latents] * 2).shape, latents.dtype, latents.device) + # # print(t.expand(latents.size(0) * 2).shape, t.dtype, t.device) + # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype, text_encoder_hidden_states.device) + # # print(predicted_effnet_latents.shape, predicted_effnet_latents.dtype, predicted_effnet_latents.device) + # predicted_image_embedding = self.generator( + # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t[None], + # effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) + # if do_classifier_free_guidance + # else predicted_effnet_latents, + # clip=text_encoder_hidden_states, + # ) + + # if do_classifier_free_guidance: + # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + # predicted_image_embedding_text - predicted_image_embedding_uncond + # ) + # # print(t) + # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( + # # dtype=t.dtype + # # ) + + # timestep = (t * 999).cpu().int() + # # print(timestep) + # latents = self.scheduler.step( + # predicted_image_embedding, + # timestep=timestep - 1, + # sample=latents, + # generator=generator, + # ).prev_sample + + # return latents @torch.no_grad() def __call__( @@ -522,7 +540,7 @@ def __call__( predicted_image_embeddings: torch.Tensor, text_encoder_hidden_states: torch.Tensor, inference_steps: dict = None, - guidance_scale: float = 0.0, + guidance_scale: float = 3.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pt", # pt only @@ -546,6 +564,7 @@ def __call__( total_num_inference_steps = sum(inference_steps.values()) self.scheduler.set_timesteps(total_num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps latents = self.prepare_latents( effnet_features_shape, @@ -555,20 +574,45 @@ def __call__( latents, self.scheduler, ) - # print(generator_timesteps_tensor) - t_start = 1.0 - for t_end, steps in inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) - latents = self.inference_loop( - latents, - steps, - predicted_image_embeddings, - text_encoder_hidden_states, - do_classifier_free_guidance, - guidance_scale, - generator, + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) + predicted_image_embedding = self.generator( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio[None], + effnet=torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) + if do_classifier_free_guidance + else predicted_image_embeddings, + clip=text_encoder_hidden_states, ) - t_start = t_end + + if do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + ).prev_sample + + # # print(generator_timesteps_tensor) + # t_start = 1.0 + # for t_end, steps in inference_steps.items(): + # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) + # latents = self.inference_loop( + # latents, + # steps, + # predicted_image_embeddings, + # text_encoder_hidden_states, + # do_classifier_free_guidance, + # guidance_scale, + # generator, + # ) + # t_start = t_end images = self.vqgan.decode(latents).sample From fe972d693ac2178d6053bab9713aef21cc0a0210 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Jul 2023 11:41:36 +0200 Subject: [PATCH 044/181] Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen --- src/diffusers/models/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 1863a75957a8..21c27c8c4a35 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -682,9 +682,9 @@ def __init__(self, dim): self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, inputs): - Gx = torch.norm(inputs, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (inputs * Nx) + self.beta + inputs + gx = torch.norm(inputs, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (inputs * nx) + self.beta + inputs class GlobalResponseResidualBlock(nn.Module): From db5dd654dcc2b0125972e87cd0f50aa74a20d456 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Jul 2023 12:23:05 +0200 Subject: [PATCH 045/181] use torch.lerp --- .../wuerstchen/pipeline_wuerstchen.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 51cab1e8ff66..26d000e88ab0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -75,10 +75,8 @@ class WuerstchenGeneratorPipelineOutput(BaseOutput): Output class for WuerstchenPriorPipeline. Args: - image_embeds (`torch.FloatTensor` or `np.ndarray`) - Prior image embeddings for text prompt - text_embeds (`torch.FloatTensor` or `np.ndarray`) - Clip text embeddings for unconditional tokens + images (`torch.FloatTensor` or `np.ndarray`) + Generated images for text prompt. """ images: Union[torch.FloatTensor, np.ndarray] @@ -320,20 +318,20 @@ def __call__( self.scheduler.set_timesteps(total_num_inference_steps, device=device) prior_timesteps_tensor = self.scheduler.timesteps - def seed_everything(seed: int): - import random, os - import numpy as np - import torch + # def seed_everything(seed: int): + # import random, os + # import numpy as np + # import torch - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True + # random.seed(seed) + # os.environ["PYTHONHASHSEED"] = str(seed) + # np.random.seed(seed) + # torch.manual_seed(seed) + # torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True - seed_everything(42) + # seed_everything(42) latents = self.prepare_latents( effnet_features_shape, @@ -358,8 +356,8 @@ def seed_everything(seed: int): if do_classifier_free_guidance: predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale ) latents = self.scheduler.step( @@ -588,8 +586,8 @@ def __call__( if do_classifier_free_guidance: predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale ) latents = self.scheduler.step( From 44d6a0404bb99f2d4499a9b31db3dca21d788354 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Jul 2023 12:23:21 +0200 Subject: [PATCH 046/181] use warp-diffusion org --- scripts/convert_wuerstchen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index ba74bf376b92..a2e1ca45722b 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -61,7 +61,7 @@ scheduler=scheduler, ) -prior_pipeline.save_pretrained("kashif/WuerstchenPriorPipeline") +prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") generator_pipeline = WuerstchenGeneratorPipeline( vqgan=vqmodel, @@ -69,4 +69,4 @@ efficient_net=efficient_net, scheduler=scheduler, ) -generator_pipeline.save_pretrained("kashif/WuerstchenGeneratorPipeline") +generator_pipeline.save_pretrained("warp-diffusion/WuerstchenGeneratorPipeline") From 180bbaec2f42e21ed17f5f854ae192afa51595c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Jul 2023 13:18:53 +0200 Subject: [PATCH 047/181] clip_sample=False, --- scripts/convert_wuerstchen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index a2e1ca45722b..cb7d395e9792 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -51,7 +51,10 @@ trained_betas = [] # scheduler -scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2") +scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", + clip_sample=False, +) # Prior pipeline prior_pipeline = WuerstchenPriorPipeline( From 368e113235f61bb0e183d390c814f755d78d0959 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 Jul 2023 23:06:59 +0200 Subject: [PATCH 048/181] some refactoring --- scripts/convert_wuerstchen.py | 4 +- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 3 +- src/diffusers/models/vq_model.py | 2 +- src/diffusers/models/vq_paella.py | 136 +++++++ .../pipelines/wuerstchen/__init__.py | 6 +- src/diffusers/pipelines/wuerstchen/modules.py | 79 +--- .../wuerstchen/pipeline_wuerstchen.py | 354 +---------------- .../wuerstchen/pipeline_wuerstchen_prior.py | 361 ++++++++++++++++++ src/diffusers/pipelines/wuerstchen/prior.py | 136 +++++++ src/diffusers/utils/dummy_pt_objects.py | 2 +- .../wuerstchen/test_wuerstchen_prior.py | 187 +++++++++ 12 files changed, 836 insertions(+), 436 deletions(-) create mode 100644 src/diffusers/models/vq_paella.py create mode 100644 src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py create mode 100644 src/diffusers/pipelines/wuerstchen/prior.py create mode 100644 tests/pipelines/wuerstchen/test_wuerstchen_prior.py diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index cb7d395e9792..06ac7e771d1d 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -6,7 +6,7 @@ from diffusers import ( DDPMScheduler, - PaellaVQModel, + VQModelPaella, WuerstchenPriorPipeline, WuerstchenGeneratorPipeline, ) @@ -22,7 +22,7 @@ state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] state_dict.pop("vquantizer.codebook.weight") -vqmodel = PaellaVQModel( +vqmodel = VQModelPaella( codebook_size=paella_vqmodel.codebook_size, c_latent=paella_vqmodel.c_latent, ) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8b64183de9d4..8aa9d3eda3a5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,7 +39,7 @@ AutoencoderKL, ControlNetModel, ModelMixin, - PaellaVQModel, + VQModelPaella, MultiAdapter, PriorTransformer, T2IAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index cc061efb6892..2d708e21f34a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,7 +28,8 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel - from .vq_model import PaellaVQModel, VQModel + from .vq_model import VQModel + from .vq_paella import VQModelPaella if is_flax_available(): from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index e888b8f50bc5..ab8fa008bd78 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -37,7 +37,7 @@ class VQEncoderOutput(BaseOutput): latents: torch.FloatTensor -class PaellaVQModel(ModelMixin, ConfigMixin): +class VQModelPaella(ModelMixin, ConfigMixin): r"""VQ-VAE model from Paella model. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library diff --git a/src/diffusers/models/vq_paella.py b/src/diffusers/models/vq_paella.py new file mode 100644 index 000000000000..0dba4634fb4a --- /dev/null +++ b/src/diffusers/models/vq_paella.py @@ -0,0 +1,136 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from .modeling_utils import ModelMixin +from .resnet import MixingResidualBlock +from .vae import DecoderOutput, VectorQuantizer +from .vq_model import VQEncoderOutput + + +class VQModelPaella(ModelMixin, ConfigMixin): + r"""VQ-VAE model from Paella model. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. + levels (int, *optional*, defaults to 2): Number of levels in the model. + bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. + c_hidden (int, *optional*, defaults to 384): Number of hidden channels in the model. + c_latent (int, *optional*, defaults to 4): Number of latent channels in the model. + codebook_size (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_down_scale_factor: int = 2, + levels: int = 2, + bottleneck_blocks: int = 12, + c_hidden: int = 384, + c_latent: int = 4, + codebook_size: int = 8192, + scale_factor: float = 0.3764, + ): + super().__init__() + + c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] + self.in_block = nn.Sequential( + nn.PixelUnshuffle(up_down_scale_factor), + nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), + ) + + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + self.vquantizer = VectorQuantizer(codebook_size, vq_embed_dim=c_latent, legacy=False, beta=0.25) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d( + c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 + ) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), + nn.PixelShuffle(up_down_scale_factor), + ) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.in_block(x) + h = self.down_blocks(h) / self.config.scale_factor + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if not force_not_quantize: + quant, _, _ = self.vquantizer(h * self.config.scale_factor) + else: + quant = h * self.config.scale_factor + + x = self.up_blocks(quant) + dec = self.out_block(x) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 0b312ce9a867..4cf1d2bcda88 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,5 +2,7 @@ if is_transformers_available() and is_torch_available(): - from .modules import Prior, DiffNeXt, EfficientNetEncoder - from .pipeline_wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline + from .modules import DiffNeXt, EfficientNetEncoder + from .prior import Prior + from .pipeline_wuerstchen import WuerstchenGeneratorPipeline + from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index df34ad0e7d4b..3472f3101689 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -3,7 +3,6 @@ import numpy as np import torch import torch.nn as nn -from torchvision.models import efficientnet_v2_l, efficientnet_v2_s from diffusers.configuration_utils import ConfigMixin, register_to_config @@ -114,6 +113,7 @@ class EfficientNetEncoder(ModelMixin, ConfigMixin): @register_to_config def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): super().__init__() + from torchvision.models import efficientnet_v2_l, efficientnet_v2_s # can't use `torchvision` if effnet == "efficientnet_v2_s": self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() else: @@ -128,79 +128,6 @@ def forward(self, x): return self.mapper(self.backbone(x)) -class Prior(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): - super().__init__() - self.c_r = c_r - self.projection = nn.Conv2d(c_in, c, kernel_size=1) - self.cond_mapper = nn.Sequential( - nn.Linear(c_cond, c), - nn.LeakyReLU(0.2), - nn.Linear(c, c), - ) - - self.blocks = nn.ModuleList() - for _ in range(depth): - self.blocks.append(ResBlock(c, dropout=dropout)) - self.blocks.append(TimestepBlock(c, c_r)) - self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) - self.out = nn.Sequential( - LayerNorm2d(c, elementwise_affine=False, eps=1e-6), - nn.Conv2d(c, c_in * 2, kernel_size=1), - ) - - self.apply(self._init_weights) # General init - nn.init.normal_(self.projection.weight, std=0.02) # inputs - nn.init.normal_(self.cond_mapper[0].weight, std=0.02) # conditionings - nn.init.normal_(self.cond_mapper[-1].weight, std=0.02) # conditionings - nn.init.constant_(self.out[1].weight, 0) # outputs - - # blocks - for block in self.blocks: - if isinstance(block, ResBlock): - block.channelwise[-1].weight.data *= np.sqrt(1 / depth) - elif isinstance(block, TimestepBlock): - nn.init.constant_(block.mapper.weight, 0) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - torch.nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode="constant") - return emb.to(dtype=r.dtype) - - def forward(self, x, r, c): - x_in = x - x = self.projection(x) - c_embed = self.cond_mapper(c) - r_embed = self.gen_r_embedding(r) - for block in self.blocks: - if isinstance(block, AttnBlock): - x = block(x, c_embed) - elif isinstance(block, TimestepBlock): - x = block(x, r_embed) - else: - x = block(x) - a, b = self.out(x).chunk(2, dim=1) - # denoised = a / (1-(1-b).pow(2)).sqrt() - return (x_in - a) / ((1 - b).abs() + 1e-5) - - def update_weights_ema(self, src_model, beta=0.999): - for self_params, src_params in zip(self.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data * (1 - beta) - - class DiffNeXt(ModelMixin, ConfigMixin): @register_to_config def __init__( @@ -405,7 +332,3 @@ def forward(self, x, r, effnet, clip, x_cat=None, eps=1e-3, return_noise=True): return (x_in - a) / b else: return a, b - - def update_weights_ema(self, src_model, beta=0.999): - for self_params, src_params in zip(self.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data * (1 - beta) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 26d000e88ab0..accec77345b6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -15,17 +15,14 @@ from dataclasses import dataclass from typing import List, Optional, Union -from PIL import Image import numpy as np import torch -from PIL.Image import Image -from transformers import CLIPTextModel, CLIPTokenizer -from ...models import PaellaVQModel +from ...models import VQModelPaella from ...schedulers import DDPMScheduler from ...utils import BaseOutput, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from .modules import DiffNeXt, Prior, EfficientNetEncoder +from .modules import DiffNeXt, EfficientNetEncoder # from .diffuzz import Diffuzz @@ -53,22 +50,6 @@ default_inference_steps_b = {0.0: 30} -@dataclass -class WuerstchenPriorPipelineOutput(BaseOutput): - """ - Output class for WuerstchenPriorPipeline. - - Args: - image_embeds (`torch.FloatTensor` or `np.ndarray`) - Prior image embeddings for text prompt - text_embeds (`torch.FloatTensor` or `np.ndarray`) - Clip text embeddings for unconditional tokens - """ - - image_embeds: Union[torch.FloatTensor, np.ndarray] - text_embeds: Union[torch.FloatTensor, np.ndarray] - - @dataclass class WuerstchenGeneratorPipelineOutput(BaseOutput): """ @@ -82,315 +63,6 @@ class WuerstchenGeneratorPipelineOutput(BaseOutput): images: Union[torch.FloatTensor, np.ndarray] -class WuerstchenPriorPipeline(DiffusionPipeline): - """ - Pipeline for generating image prior for Wuerstchen. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - prior ([`Prior`]): - The canonical unCLIP prior to approximate the image embedding from the text embedding. - text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - scheduler ([`DDPMScheduler`]): - A scheduler to be used in combination with `prior` to generate image embedding. - """ - - def __init__( - self, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prior: Prior, - scheduler: DDPMScheduler, - ) -> None: - super().__init__() - self.multiple = 128 - self.register_modules( - tokenizer=tokenizer, - text_encoder=text_encoder, - prior=prior, - scheduler=scheduler, - ) - # self.diffuzz = Diffuzz(device="cuda") - self.register_to_config() - - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - latents = latents * scheduler.init_noise_sigma - return latents - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - ): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - - text_encoder_output = self.text_encoder(text_input_ids.to(device)) - - text_encoder_hidden_states = text_encoder_output.last_hidden_state - - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - - uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - - seq_len = uncond_text_encoder_hidden_states.shape[1] - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - # done duplicates - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) - - return text_encoder_hidden_states - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): - return self.device - for module in self.text_encoder.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # @torch.no_grad() - # def inference_loop( - # self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - # ): - # print(steps) - # print(steps[:-1]) - # for i, t in enumerate(self.progress_bar(steps[:-1])): - # # print(torch.cat([latents] * 2).shape, latents.dtype) - # # print(t.expand(latents.size(0) * 2).shape, t.dtype) - # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - # predicted_image_embedding = self.prior( - # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, - # c=text_encoder_hidden_states, - # ) - - # # print(t.expand(latents.size(0) * 2)) - # # print(i, predicted_image_embedding[0, 0, :4, :4]) - # # print(text_encoder_hidden_states[0, 4, :4]) - - # if do_classifier_free_guidance: - # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - # predicted_image_embedding_text - predicted_image_embedding_uncond - # ) - # # print(t) - - # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - # # dtype=t.dtype - # # ) - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - # return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: int = 1024, - width: int = 1024, - inference_steps: dict = None, - guidance_scale: float = 3.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pt", # pt only - return_dict: bool = True, - ): - device = self._execution_device - - do_classifier_free_guidance = guidance_scale > 1.0 - - if inference_steps is None: - inference_steps = default_inference_steps_c - - if negative_prompt is None: - negative_prompt = "" - - if isinstance(prompt, str): - prompt = [prompt] - elif not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - elif not isinstance(negative_prompt, list) and negative_prompt is not None: - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - - text_encoder_hidden_states = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - dtype = text_encoder_hidden_states.dtype - latent_height = 128 * (height // 128) // (1024 // 24) - latent_width = 128 * (width // 128) // (1024 // 24) - effnet_features_shape = (num_images_per_prompt, 16, latent_height, latent_width) - - total_num_inference_steps = sum(inference_steps.values()) - self.scheduler.set_timesteps(total_num_inference_steps, device=device) - prior_timesteps_tensor = self.scheduler.timesteps - - # def seed_everything(seed: int): - # import random, os - # import numpy as np - # import torch - - # random.seed(seed) - # os.environ["PYTHONHASHSEED"] = str(seed) - # np.random.seed(seed) - # torch.manual_seed(seed) - # torch.cuda.manual_seed(seed) - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = True - - # seed_everything(42) - - latents = self.prepare_latents( - effnet_features_shape, - dtype, - device, - generator, - latents, - self.scheduler, - ) - - # latents = torch.randn(effnet_features_shape, device=device) - # print(latents[0, 0, :4, :4]) - # latents = latents.to(dtype=dtype) - - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) # between 0 and 1 - predicted_image_embedding = self.prior( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio, - c=text_encoder_hidden_states, - ) - - if do_classifier_free_guidance: - predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = torch.lerp( - predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale - ) - - latents = self.scheduler.step( - predicted_image_embedding, - timestep=t, - sample=latents, - generator=generator, - ).prev_sample - - # t_start = 1.0 - # for t_end, steps in inference_steps.items(): - # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) - # latents = self.inference_loop( - # latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - # ) - # t_start = t_end - - # normalize the latents - latents = latents * 42.0 - 1.0 - - if output_type not in ["pt", "np"]: - raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") - - if output_type == "np": - latents = latents.cpu().numpy() - text_encoder_hidden_states = text_encoder_hidden_states.cpu().numpy() - - if not return_dict: - return (latents, text_encoder_hidden_states) - - return WuerstchenPriorPipelineOutput(latents, text_encoder_hidden_states) - - class WuerstchenGeneratorPipeline(DiffusionPipeline): """ Pipeline for generating images from the Wuerstchen model. @@ -401,7 +73,7 @@ class WuerstchenGeneratorPipeline(DiffusionPipeline): Args: generator ([`DiffNeXt`]): The DiffNeXt unet generator. - vqgan ([`PaellaVQModel`]): + vqgan ([`VQModelPaella`]): The VQGAN model. efficient_net ([`EfficientNetEncoder`]): The EfficientNet encoder. @@ -413,7 +85,7 @@ def __init__( self, generator: DiffNeXt, scheduler: DDPMScheduler, - vqgan: PaellaVQModel, + vqgan: VQModelPaella, efficient_net: EfficientNetEncoder, ) -> None: super().__init__() @@ -439,24 +111,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): - return self.device - for module in self.text_encoder.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - def check_inputs( self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device ): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py new file mode 100644 index 000000000000..4665176419d0 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -0,0 +1,361 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMScheduler +from ...utils import BaseOutput, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .prior import Prior + +# from .diffuzz import Diffuzz + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained("kashif/wuerstchen-prior", torch_dtype=torch.float16).to("cuda") + >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain("kashif/wuerstchen-gen", torch_dtype=torch.float16).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeds, prior_output.text_embeds) + ``` +""" + + +default_inference_steps_c = {2 / 3: 20, 0.0: 10} +# default_inference_steps_c = {0.0: 60} +default_inference_steps_b = {0.0: 30} + + +@dataclass +class WuerstchenPriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeds (`torch.FloatTensor` or `np.ndarray`) + Prior image embeddings for text prompt + text_embeds (`torch.FloatTensor` or `np.ndarray`) + Clip text embeddings for unconditional tokens + """ + + image_embeds: Union[torch.FloatTensor, np.ndarray] + text_embeds: Union[torch.FloatTensor, np.ndarray] + + +class WuerstchenPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prior: Prior, + scheduler: DDPMScheduler, + ) -> None: + super().__init__() + self.multiple = 128 + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + prior=prior, + scheduler=scheduler, + ) + # self.diffuzz = Diffuzz(device="cuda") + self.register_to_config() + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) + + return text_encoder_hidden_states + + # @torch.no_grad() + # def inference_loop( + # self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + # ): + # print(steps) + # print(steps[:-1]) + # for i, t in enumerate(self.progress_bar(steps[:-1])): + # # print(torch.cat([latents] * 2).shape, latents.dtype) + # # print(t.expand(latents.size(0) * 2).shape, t.dtype) + # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) + # predicted_image_embedding = self.prior( + # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, + # c=text_encoder_hidden_states, + # ) + + # # print(t.expand(latents.size(0) * 2)) + # # print(i, predicted_image_embedding[0, 0, :4, :4]) + # # print(text_encoder_hidden_states[0, 4, :4]) + + # if do_classifier_free_guidance: + # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + # predicted_image_embedding_text - predicted_image_embedding_uncond + # ) + # # print(t) + + # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( + # # dtype=t.dtype + # # ) + # timestep = (t * 999).cpu().int() + # # print(timestep) + # latents = self.scheduler.step( + # predicted_image_embedding, + # timestep=timestep - 1, + # sample=latents, + # generator=generator, + # ).prev_sample + + # return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: Optional[int] = None, # TODO(Kashif) - this should not stay None as a default & it should replace inference_steps + inference_steps: dict = None, + guidance_scale: float = 3.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if inference_steps is None: + inference_steps = default_inference_steps_c + + if negative_prompt is None: + negative_prompt = "" + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + text_encoder_hidden_states = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + dtype = text_encoder_hidden_states.dtype + latent_height = 128 * (height // 128) // (1024 // 24) + latent_width = 128 * (width // 128) // (1024 // 24) + num_channels = self.prior.config.c_in + effnet_features_shape = (num_images_per_prompt, num_channels, latent_height, latent_width) + + if num_inference_steps is None: + num_inference_steps = sum(inference_steps.values()) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # def seed_everything(seed: int): + # import random, os + # import numpy as np + # import torch + + # random.seed(seed) + # os.environ["PYTHONHASHSEED"] = str(seed) + # np.random.seed(seed) + # torch.manual_seed(seed) + # torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + + # seed_everything(42) + + latents = self.prepare_latents( + effnet_features_shape, + dtype, + device, + generator, + latents, + self.scheduler, + ) + + # latents = torch.randn(effnet_features_shape, device=device) + # print(latents[0, 0, :4, :4]) + # latents = latents.to(dtype=dtype) + + # for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + for t in self.progress_bar(self.scheduler.timesteps): + ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) # between 0 and 1 + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents, + r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, + ) + + if do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale + ) + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + ).prev_sample + + # t_start = 1.0 + # for t_end, steps in inference_steps.items(): + # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) + # latents = self.inference_loop( + # latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator + # ) + # t_start = t_end + + # normalize the latents + latents = latents * 42.0 - 1.0 + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + latents = latents.cpu().numpy() + text_encoder_hidden_states = text_encoder_hidden_states.cpu().numpy() + + if not return_dict: + return (latents, text_encoder_hidden_states) + + return WuerstchenPriorPipelineOutput(latents, text_encoder_hidden_states) diff --git a/src/diffusers/pipelines/wuerstchen/prior.py b/src/diffusers/pipelines/wuerstchen/prior.py new file mode 100644 index 000000000000..f3dae497afdc --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/prior.py @@ -0,0 +1,136 @@ +import math + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from ...models.modeling_utils import ModelMixin + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class Prior(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): + super().__init__() + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + LayerNorm2d(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + # denoised = a / (1-(1-b).pow(2)).sqrt() + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 38b07c7f86d8..ef2c8a053850 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PaellaVQModel(metaclass=DummyObject): +class VQModelPaella(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py new file mode 100644 index 000000000000..c315f9fe63eb --- /dev/null +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, +) + +from diffusers import WuerstchenPriorPipeline, DDPMScheduler +from diffusers.pipelines.wuerstchen import Prior +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import enable_full_determinism, skip_mps + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WuerstchenPriorPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "num_images_per_prompt", + "generator", + "num_inference_steps", + "latents", + "negative_prompt", + "guidance_scale", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "c_in": 2, + "c": 8, + "depth": 2, + "c_cond": 37, + "c_r": 8, + "nhead": 2, + "latent_size": (2, 2), + } + + model = Prior(**model_kwargs) + return model + + def get_dummy_components(self): + prior = self.dummy_prior + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + + scheduler = DDPMScheduler() + + components = { + "prior": prior, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "guidance_scale": 4.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_kandinsky_prior(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.image_embeds + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -10:] + image_from_tuple_slice = image_from_tuple[0, -10:] + + assert image.shape == (1, 32) + + expected_slice = np.array( + [-0.0532, 1.7120, 0.3656, -1.0852, -0.8946, -1.1756, 0.4348, 0.2482, 0.5146, -0.1156] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @skip_mps + def test_inference_batch_single_identical(self): + test_max_difference = torch_device == "cpu" + relax_max_difference = True + test_mean_pixel_difference = False + + self._test_inference_batch_single_identical( + test_max_difference=test_max_difference, + relax_max_difference=relax_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) + + @skip_mps + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + test_mean_pixel_difference = False + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) From 0c0bedc206a28485a83459cdaced968da34b5f39 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 20 Jul 2023 10:09:19 +0200 Subject: [PATCH 049/181] use model_v3_stage_c --- scripts/convert_wuerstchen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 06ac7e771d1d..88279cb8c7f0 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -43,7 +43,7 @@ generator.load_state_dict(state_dict["state_dict"]) # Prior -state_dict = torch.load(os.path.join(model_path, "model_v2_stage_c.pt"), map_location=device) +state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["ema_state_dict"]) From 1247ae9fc800a27de6d1c16e73fec167e2173a0e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 20 Jul 2023 10:16:28 +0200 Subject: [PATCH 050/181] c_cond size --- scripts/convert_wuerstchen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 88279cb8c7f0..8d7ce0e695d0 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -44,7 +44,7 @@ # Prior state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) -prior_model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=32, nhead=24).to(device) +prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["ema_state_dict"]) # Trained betas for scheduler via cosine From b0435b332c9971fa8b62b04e7f6932fd3c7553b6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 20 Jul 2023 14:25:00 +0200 Subject: [PATCH 051/181] use clip-bigG --- scripts/convert_wuerstchen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 8d7ce0e695d0..0dda2254a1ba 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -30,8 +30,8 @@ # TODO: test vqmodel outputs match paella_vqmodel outputs # Clip Text encoder and tokenizer -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") # EfficientNet state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) From 9bdc662af689f3a79d3aa711cc8bbbfbc2ff7b08 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 21 Jul 2023 11:36:38 +0200 Subject: [PATCH 052/181] allow stage b clip to be None --- src/diffusers/pipelines/wuerstchen/modules.py | 21 ++++++++++++------- .../wuerstchen/pipeline_wuerstchen.py | 4 ++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 3472f3101689..715a0b0a221e 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -32,11 +32,13 @@ def __init__(self, c, nhead, dropout=0.0): super().__init__() self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv=None, self_attn=False): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn: + if self_attn and kv is not None: kv = torch.cat([x, kv], dim=1) + elif kv is None: + kv = x x = self.attn(x, kv, kv, need_weights=False)[0] x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -103,8 +105,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): self.attention = Attention2D(c, nhead, dropout) self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) - def forward(self, x, kv): - kv = self.kv_mapper(kv) + def forward(self, x, kv=None): + if kv is not None: + kv = self.kv_mapper(kv) x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) return x @@ -114,6 +117,7 @@ class EfficientNetEncoder(ModelMixin, ConfigMixin): def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): super().__init__() from torchvision.models import efficientnet_v2_l, efficientnet_v2_s # can't use `torchvision` + if effnet == "efficientnet_v2_s": self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() else: @@ -262,7 +266,7 @@ def gen_c_embeddings(self, clip): clip = self.seq_norm(clip) return clip - def _down_encode(self, x, r_embed, effnet, clip): + def _down_encode(self, x, r_embed, effnet, clip=None): level_outputs = [] for i, down_block in enumerate(self.down_blocks): effnet_c = None @@ -286,7 +290,7 @@ def _down_encode(self, x, r_embed, effnet, clip): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, effnet, clip): + def _up_decode(self, level_outputs, r_embed, effnet, clip=None): x = level_outputs[0] for i, up_block in enumerate(self.up_blocks): effnet_c = None @@ -314,12 +318,13 @@ def _up_decode(self, level_outputs, r_embed, effnet, clip): x = block(x) return x - def forward(self, x, r, effnet, clip, x_cat=None, eps=1e-3, return_noise=True): + def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True): if x_cat is not None: x = torch.cat([x, x_cat], dim=1) # Process the conditioning embeddings r_embed = self.gen_r_embedding(r) - clip = self.gen_c_embeddings(clip) + if clip is not None: + clip = self.gen_c_embeddings(clip) # Model Blocks x_in = x diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index accec77345b6..e9da5d3049e6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -209,7 +209,7 @@ def __call__( predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device ) - dtype = text_encoder_hidden_states.dtype + dtype = predicted_image_embeddings.dtype latent_height = int(predicted_image_embeddings.size(2) * (256 / 24)) latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) effnet_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) @@ -235,7 +235,7 @@ def __call__( effnet=torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) if do_classifier_free_guidance else predicted_image_embeddings, - clip=text_encoder_hidden_states, + clip=None, ) if do_classifier_free_guidance: From c05d6c51217fb4b30c9b4c766793769d0f6479ec Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 21 Jul 2023 14:29:59 +0200 Subject: [PATCH 053/181] add dummy --- src/diffusers/utils/dummy_pt_objects.py | 22 +++++++------- .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3c78a10e4cb3..081438c062e4 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -62,7 +62,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class MultiAdapter(metaclass=DummyObject): +class VQModelPaella(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -77,7 +77,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PriorTransformer(metaclass=DummyObject): +class MultiAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -92,7 +92,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T2IAdapter(metaclass=DummyObject): +class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T5FilmDecoder(metaclass=DummyObject): +class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -122,7 +122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Transformer2DModel(metaclass=DummyObject): +class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -137,7 +137,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet1DModel(metaclass=DummyObject): +class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -152,7 +152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DConditionModel(metaclass=DummyObject): +class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -167,7 +167,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DModel(metaclass=DummyObject): +class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -182,7 +182,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet3DConditionModel(metaclass=DummyObject): +class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQModel(metaclass=DummyObject): +class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQModelPaella(metaclass=DummyObject): +class VQModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 016760337c69..299744ed7d09 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -960,3 +960,33 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenGeneratorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenPriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) From 2d9d85d8e6622df91d0b6b0386f732aca89efe5e Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Sun, 23 Jul 2023 19:49:22 +0200 Subject: [PATCH 054/181] =?UTF-8?q?w=C3=BCrstchen=20scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- scripts/convert_wuerstchen.py | 7 +- scripts/vqgan.py | 144 ++++++++++ scripts/wuerstchen_pipeline_test.py | 87 +++++- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/wuerstchen/diffuzz.py | 3 +- .../wuerstchen/pipeline_wuerstchen.py | 117 +++----- .../wuerstchen/pipeline_wuerstchen_prior.py | 49 +--- src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_ddpm_wuerstchen.py | 249 ++++++++++++++++++ 10 files changed, 521 insertions(+), 142 deletions(-) create mode 100644 scripts/vqgan.py create mode 100644 src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py diff --git a/.gitignore b/.gitignore index 45602a1f547e..94835edf4bdd 100644 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,7 @@ tags # ruff .ruff_cache -wandb \ No newline at end of file +wandb +scripts/models/ +scripts/warp-diffusion/ +scripts/samples/ diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 0dda2254a1ba..8708ea7a2977 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -5,7 +5,7 @@ from vqgan import VQModel from diffusers import ( - DDPMScheduler, + DDPMWuerstchenScheduler, VQModelPaella, WuerstchenPriorPipeline, WuerstchenGeneratorPipeline, @@ -51,10 +51,7 @@ trained_betas = [] # scheduler -scheduler = DDPMScheduler( - beta_schedule="squaredcos_cap_v2", - clip_sample=False, -) +scheduler = DDPMWuerstchenScheduler() # Prior pipeline prior_pipeline = WuerstchenPriorPipeline( diff --git a/scripts/vqgan.py b/scripts/vqgan.py new file mode 100644 index 000000000000..935023e1ff6e --- /dev/null +++ b/scripts/vqgan.py @@ -0,0 +1,144 @@ +import torch +from torch import nn +import numpy as np +import math +from tqdm import tqdm +import time +from torchtools.nn import VectorQuantize + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + +class VQModel(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3*4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i]*4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) + up_blocks.append(block) + if i < levels-1: + up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3*4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x): + x = self.in_block(x) + x = self.down_blocks(x) + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def decode_indices(self, x): + x = self.vquantizer.idx2vq(x, dim=1) + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x \ No newline at end of file diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index c7a1c6a1679a..320b3c4920d9 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -1,11 +1,15 @@ import os import numpy as np import torch +import torchvision +import transformers from PIL import Image +from transformers import AutoTokenizer, CLIPTextModel from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline +transformers.utils.logging.set_verbosity_error() -def numpy_to_pil(images: np.ndarray) -> list[Image]: +def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: """ Convert a numpy image or a batch of images to a PIL image. """ @@ -16,25 +20,84 @@ def numpy_to_pil(images: np.ndarray) -> list[Image]: return pil_images +effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(768, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + torchvision.transforms.CenterCrop(768), + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) +]) -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenPriorPipeline", torch_dtype=torch.float16) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\kashif\\WuerstchenGeneratorPipeline", torch_dtype=torch.float16) +transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(1024), + torchvision.transforms.RandomCrop(1024), +]) +device = "cuda" +dtype = torch.float16 +batch_size = 2 + +# generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) +# generator_pipeline = generator_pipeline.to("cuda") +# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cuda") +# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + +# image = Image.open("C:\\Users\\d6582\\Documents\\ml\\wand\\finetuning\\images\\fernando\\IMG_0352.JPG") +# image = effnet_preprocess(transforms(image).unsqueeze(0).expand(batch_size, -1, -1, -1)).to("cuda").to(dtype) +# print(image.shape) + +# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" +# negative_prompt = "low resolution, low detail, bad quality, blurry" + +# clip_tokens = tokenizer([caption] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") +# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype) +# clip_tokens_uncond = tokenizer([negative_prompt] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") +# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype) + +# image_embeds = generator_pipeline.encode_image(image) +# generator_output = generator_pipeline(image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images +# images = numpy_to_pil(generator_output) +# os.makedirs("samples", exist_ok=True) +# for i, image in enumerate(images): +# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) + + +prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenPriorPipeline", torch_dtype=dtype) +generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") +text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") negative_prompt = "low resolution, low detail, bad quality, blurry" # negative_prompt = "" # caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = input("Prompt please: ") -while caption != "q": - prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) - generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds, output_type="np").images - images = numpy_to_pil(generator_output) +caption = "An armchair in the shape of an avocado" +clip_tokens = tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") +clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) +clip_tokens_uncond = tokenizer([negative_prompt] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") +clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) + +prior_output = prior_pipeline(caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt) +generator_output = generator_pipeline(prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images +images = numpy_to_pil(generator_output) +os.makedirs("samples", exist_ok=True) +for i, image in enumerate(images): + image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) + + + + +# caption = input("Prompt please: ") +# while caption != "q": +# prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) +# generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds, output_type="np").images +# images = numpy_to_pil(generator_output) - os.makedirs("samples", exist_ok=True) - for i, image in enumerate(images): - image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) +# os.makedirs("samples", exist_ok=True) +# for i, image in enumerate(images): +# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - caption = input("Prompt please: ") +# caption = input("Prompt please: ") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c95421c37779..83fb2a04706d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ DDIMParallelScheduler, DDIMScheduler, DDPMParallelScheduler, + DDPMWuerstchenScheduler, DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, diff --git a/src/diffusers/pipelines/wuerstchen/diffuzz.py b/src/diffusers/pipelines/wuerstchen/diffuzz.py index 25c3db27dfd2..79e1760c3fc6 100644 --- a/src/diffusers/pipelines/wuerstchen/diffuzz.py +++ b/src/diffusers/pipelines/wuerstchen/diffuzz.py @@ -80,8 +80,7 @@ def undiffuse(self, x, t, t_prev, noise, sampler=None): def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, unconditional_inputs=None, sampler='ddpm', half=False): - r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[ - 0] if x_init is None else x_init.size(0)).to(self.device) + r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) if isinstance(sampler, str): if sampler in sampler_dict: sampler = sampler_dict[sampler](self) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index e9da5d3049e6..6eec8622a83d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -19,7 +19,7 @@ import torch from ...models import VQModelPaella -from ...schedulers import DDPMScheduler +from ...schedulers import DDPMWuerstchenScheduler from ...utils import BaseOutput, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from .modules import DiffNeXt, EfficientNetEncoder @@ -45,9 +45,7 @@ """ -default_inference_steps_c = {2 / 3: 20, 0.0: 10} -# default_inference_steps_c = {0.0: 60} -default_inference_steps_b = {0.0: 30} +default_inference_steps_b = {0.0: 12} @dataclass @@ -84,7 +82,7 @@ class WuerstchenGeneratorPipeline(DiffusionPipeline): def __init__( self, generator: DiffNeXt, - scheduler: DDPMScheduler, + scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, efficient_net: EfficientNetEncoder, ) -> None: @@ -108,7 +106,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) - latents = latents * scheduler.init_noise_sigma return latents def check_inputs( @@ -140,57 +137,15 @@ def check_inputs( return predicted_image_embeddings, text_encoder_hidden_states - # @torch.no_grad() - # def inference_loop( - # self, - # latents, - # steps, - # predicted_effnet_latents, - # text_encoder_hidden_states, - # do_classifier_free_guidance, - # guidance_scale, - # generator, - # ): - # for i, t in enumerate(self.progress_bar(steps[:-1])): - # # print(torch.cat([latents] * 2).shape, latents.dtype, latents.device) - # # print(t.expand(latents.size(0) * 2).shape, t.dtype, t.device) - # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype, text_encoder_hidden_states.device) - # # print(predicted_effnet_latents.shape, predicted_effnet_latents.dtype, predicted_effnet_latents.device) - # predicted_image_embedding = self.generator( - # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t[None], - # effnet=torch.cat([predicted_effnet_latents, torch.zeros_like(predicted_effnet_latents)]) - # if do_classifier_free_guidance - # else predicted_effnet_latents, - # clip=text_encoder_hidden_states, - # ) - - # if do_classifier_free_guidance: - # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - # predicted_image_embedding_text - predicted_image_embedding_uncond - # ) - # # print(t) - # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - # # dtype=t.dtype - # # ) - - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - # return latents + @torch.no_grad() + def encode_image(self, image): + return self.efficient_net(image) @torch.no_grad() def __call__( self, predicted_image_embeddings: torch.Tensor, - text_encoder_hidden_states: torch.Tensor, + text_encoder_hidden_states: torch.Tensor = None, inference_steps: dict = None, guidance_scale: float = 3.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -214,9 +169,8 @@ def __call__( latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) effnet_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) - total_num_inference_steps = sum(inference_steps.values()) - self.scheduler.set_timesteps(total_num_inference_steps, device=device) - prior_timesteps_tensor = self.scheduler.timesteps + self.scheduler.set_timesteps(inference_steps, device=device) + timesteps = self.scheduler.timesteps latents = self.prepare_latents( effnet_features_shape, @@ -226,47 +180,36 @@ def __call__( latents, self.scheduler, ) - - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) - predicted_image_embedding = self.generator( + # from transformers import AutoTokenizer, CLIPTextModel + # text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to(device) + # tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + # clip_tokens = tokenizer([""] * latents.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to(device) + # clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype) + + for t in self.progress_bar(timesteps[:-1]): + ratio = t.expand(latents.size(0)).to(dtype) + effnet=torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) if do_classifier_free_guidance else predicted_image_embeddings + predicted_latents = self.generator( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio[None], - effnet=torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) - if do_classifier_free_guidance - else predicted_image_embeddings, - clip=None, + r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, + effnet=effnet, + clip=torch.cat([text_encoder_hidden_states] * 2) if do_classifier_free_guidance else text_encoder_hidden_states, ) if do_classifier_free_guidance: - predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - predicted_image_embedding = torch.lerp( - predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp( + predicted_latents_uncond, predicted_latents_text, guidance_scale ) latents = self.scheduler.step( - predicted_image_embedding, - timestep=t, + model_output=predicted_latents, + timestep=ratio, sample=latents, generator=generator, - ).prev_sample - - # # print(generator_timesteps_tensor) - # t_start = 1.0 - # for t_end, steps in inference_steps.items(): - # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) - # latents = self.inference_loop( - # latents, - # steps, - # predicted_image_embeddings, - # text_encoder_hidden_states, - # do_classifier_free_guidance, - # guidance_scale, - # generator, - # ) - # t_start = t_end - - images = self.vqgan.decode(latents).sample + ).prediction + + images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 4665176419d0..118093ec8890 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...schedulers import DDPMScheduler +from ...schedulers import DDPMWuerstchenScheduler from ...utils import BaseOutput, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from .prior import Prior @@ -90,7 +90,7 @@ def __init__( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, prior: Prior, - scheduler: DDPMScheduler, + scheduler: DDPMWuerstchenScheduler, ) -> None: super().__init__() self.multiple = 128 @@ -111,7 +111,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) - latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt( @@ -243,9 +242,8 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: Optional[int] = None, # TODO(Kashif) - this should not stay None as a default & it should replace inference_steps inference_steps: dict = None, - guidance_scale: float = 3.0, + guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -282,25 +280,8 @@ def __call__( num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt, num_channels, latent_height, latent_width) - if num_inference_steps is None: - num_inference_steps = sum(inference_steps.values()) - - self.scheduler.set_timesteps(num_inference_steps, device=device) - - # def seed_everything(seed: int): - # import random, os - # import numpy as np - # import torch - - # random.seed(seed) - # os.environ["PYTHONHASHSEED"] = str(seed) - # np.random.seed(seed) - # torch.manual_seed(seed) - # torch.cuda.manual_seed(seed) - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = True - - # seed_everything(42) + self.scheduler.set_timesteps(inference_steps, device=device) + timesteps = self.scheduler.timesteps latents = self.prepare_latents( effnet_features_shape, @@ -311,16 +292,14 @@ def __call__( self.scheduler, ) - # latents = torch.randn(effnet_features_shape, device=device) - # print(latents[0, 0, :4, :4]) - # latents = latents.to(dtype=dtype) - - # for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - for t in self.progress_bar(self.scheduler.timesteps): - ratio = (t / self.scheduler.config.num_train_timesteps).to(dtype) # between 0 and 1 + for t in self.progress_bar(timesteps[:-1]): + ratio = t.expand(latents.size(0)).to(dtype) + # print(torch.cat([latents] * 2).shape, latents.dtype) + # print(ratio, ratio.shape, ratio.dtype) + # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) predicted_image_embedding = self.prior( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - r=ratio.expand(latents.size(0) * 2) if do_classifier_free_guidance else ratio, + r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, c=text_encoder_hidden_states, ) @@ -331,11 +310,11 @@ def __call__( ) latents = self.scheduler.step( - predicted_image_embedding, - timestep=t, + model_output=predicted_image_embedding, + timestep=ratio, sample=latents, generator=generator, - ).prev_sample + ).prediction # t_start = 1.0 # for t_end, steps in inference_steps.items(): diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 0a07ce4baed2..fda44eb032ab 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -53,6 +53,7 @@ from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler + from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py new file mode 100644 index 000000000000..55e9aede7bdf --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -0,0 +1,249 @@ +# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDPMWuerstchenSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prediction: torch.FloatTensor + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + scaler (`float`): .... + s (`float`): .... + """ + + @register_to_config + def __init__( + self, + scaler: float = 1.0, + s: float = 0.008, + ): + self.scaler = scaler + self.s = torch.tensor([s]) + self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def _alpha_cumprod(self, t, device): + if self.scaler > 1: + t = 1 - (1 - t) ** self.scaler + elif self.scaler < 1: + t = t ** self.scaler + alpha_cumprod = torch.cos((t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod.to(device) + return alpha_cumprod.clamp(0.0001, 0.9999) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + inference_steps: Optional[dict] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`Optional[int]`): + the number of diffusion steps used when generating samples with a pre-trained model. If passed, then + `timesteps` must be `None`. + device (`str` or `torch.device`, optional): + the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10} + """ + timesteps = None + t_start = 1.0 + for t_end, steps in inference_steps.items(): + steps = torch.linspace(t_start, t_end, steps + 1, device=device) + t_start = t_end + if timesteps is None: + timesteps = steps + else: + timesteps = torch.cat([timesteps, steps[1:]]) + + self.timesteps = timesteps + # print(f"Timesteps: {self.timesteps}, Timesteps Shape: {timesteps.shape}") + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + dtype = model_output.dtype + device = model_output.device + t = timestep + prev_t = self.previous_timestep(t) + + alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) + alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) + std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) + std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * std_noise + pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) + + if not return_dict: + return (pred.to(dtype),) + + return DDPMWuerstchenSchedulerOutput(prediction=pred.to(dtype)) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps + + def previous_timestep(self, timestep): + # print(f"Timestep Shape: {timestep.shape}") + # print(timestep) + # print((self.timesteps == timestep[0]).nonzero()) + # index = (self.timesteps == timestep[0]).nonzero().item() + index = (self.timesteps - timestep[0]).abs().argmin().item() + # print(f"Found index at {index}") + # print(self.timesteps[index + 1]) + prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0]) + # print(prev_t.shape, prev_t) + return prev_t From b9c3468a1ed6dca9ee8e44fd48bc7a47fb5555f2 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 24 Jul 2023 12:34:41 +0200 Subject: [PATCH 055/181] minor changes --- scripts/convert_wuerstchen.py | 3 --- scripts/wuerstchen_pipeline_test.py | 6 ++++-- src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 8708ea7a2977..1b75d07a186f 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -47,9 +47,6 @@ prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["ema_state_dict"]) -# Trained betas for scheduler via cosine -trained_betas = [] - # scheduler scheduler = DDPMWuerstchenScheduler() diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 320b3c4920d9..9a920a8c3cb9 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -35,7 +35,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: ]) device = "cuda" dtype = torch.float16 -batch_size = 2 +batch_size = 4 # generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) # generator_pipeline = generator_pipeline.to("cuda") @@ -66,10 +66,12 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") +# generator_pipeline.vqgan.to(torch.float16) text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") -negative_prompt = "low resolution, low detail, bad quality, blurry" +# negative_prompt = "low resolution, low detail, bad quality, blurry" +negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" # caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 55e9aede7bdf..5c5a217612d0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -202,6 +202,7 @@ def step( mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) + # std_noise = torch.randn_like(mu) std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * std_noise pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) From a385e66c9c8cd61566c98ed9ccedd88c2be253d2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 25 Jul 2023 11:08:26 +0200 Subject: [PATCH 056/181] set clip=None in the pipeline --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 12 +++++++----- .../schedulers/scheduling_ddpm_wuerstchen.py | 10 ++++++---- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 6eec8622a83d..00c4927ee3df 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -188,19 +188,21 @@ def __call__( for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) - effnet=torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) if do_classifier_free_guidance else predicted_image_embeddings + effnet = ( + torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) + if do_classifier_free_guidance + else predicted_image_embeddings + ) predicted_latents = self.generator( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, effnet=effnet, - clip=torch.cat([text_encoder_hidden_states] * 2) if do_classifier_free_guidance else text_encoder_hidden_states, + clip=None, # torch.cat([text_encoder_hidden_states] * 2) if do_classifier_free_guidance else text_encoder_hidden_states, ) if do_classifier_free_guidance: predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) - predicted_latents = torch.lerp( - predicted_latents_uncond, predicted_latents_text, guidance_scale - ) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale) latents = self.scheduler.step( model_output=predicted_latents, diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 5c5a217612d0..9517ee02391e 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -118,8 +118,10 @@ def _alpha_cumprod(self, t, device): if self.scaler > 1: t = 1 - (1 - t) ** self.scaler elif self.scaler < 1: - t = t ** self.scaler - alpha_cumprod = torch.cos((t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod.to(device) + t = t**self.scaler + alpha_cumprod = torch.cos( + (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5 + ) ** 2 / self._init_alpha_cumprod.to(device) return alpha_cumprod.clamp(0.0001, 0.9999) def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: @@ -198,12 +200,12 @@ def step( alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) - alpha = (alpha_cumprod / alpha_cumprod_prev) + alpha = alpha_cumprod / alpha_cumprod_prev mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) # std_noise = torch.randn_like(mu) - std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * std_noise + std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) if not return_dict: diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 081438c062e4..a9d462b4c3e8 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -540,6 +540,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DDPMWuerstchenScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] From 0db4f1910c89266a8620bc40c3dd078067e4639d Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Fri, 4 Aug 2023 18:56:42 +0200 Subject: [PATCH 057/181] fix attention mask --- .gitignore | 2 + scripts/diffuzz.py | 126 ++++++++++++++++++ scripts/the_ulitmate_test.py | 41 ++++++ scripts/wuerstchen_pipeline_test.py | 37 +++-- .../wuerstchen/pipeline_wuerstchen.py | 1 + .../wuerstchen/pipeline_wuerstchen_prior.py | 67 ++-------- .../schedulers/scheduling_ddpm_wuerstchen.py | 9 ++ 7 files changed, 220 insertions(+), 63 deletions(-) create mode 100644 scripts/diffuzz.py create mode 100644 scripts/the_ulitmate_test.py diff --git a/.gitignore b/.gitignore index 94835edf4bdd..5b7f6b1ec1de 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,5 @@ wandb scripts/models/ scripts/warp-diffusion/ scripts/samples/ +scripts/text_encoding_colab.pt +scripts/text_encoding.pt diff --git a/scripts/diffuzz.py b/scripts/diffuzz.py new file mode 100644 index 000000000000..9f80a406b952 --- /dev/null +++ b/scripts/diffuzz.py @@ -0,0 +1,126 @@ +import torch + + +# Samplers -------------------------------------------------------------------- +class SimpleSampler(): + def __init__(self, diffuzz): + self.current_step = -1 + self.diffuzz = diffuzz + + def __call__(self, *args, **kwargs): + self.current_step += 1 + return self.step(*args, **kwargs) + + def init_x(self, shape): + return torch.randn(*shape, device=self.diffuzz.device) + + def step(self, x, t, t_prev, noise): + raise NotImplementedError("You should override the 'apply' function.") + + +class DDPMSampler(SimpleSampler): + def step(self, x, t, t_prev, noise): + alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + alpha = (alpha_cumprod / alpha_cumprod_prev) + print(f"diffuzz: {alpha}") + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + print(f"diffuzz: {mu.mean()}") + torch.manual_seed(0) + std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu) + print(f"diffuzz: {std.mean()}") + return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + + +class DDIMSampler(SimpleSampler): + def step(self, x, t, t_prev, noise): + alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) + + x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() + dp_xt = (1 - alpha_cumprod_prev).sqrt() + return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise + + +sampler_dict = { + 'ddpm': DDPMSampler, + 'ddim': DDIMSampler, +} + + +# Custom simplified foward/backward diffusion (cosine schedule) +class Diffuzz(): + def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1): + self.device = device + self.s = torch.tensor([s]).to(device) + self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + self.scaler = scaler + self.cached_steps = None + if cache_steps is not None: + self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) + + def _alpha_cumprod(self, t): + if self.cached_steps is None: + if self.scaler > 1: + t = 1 - (1 - t) ** self.scaler + elif self.scaler < 1: + t = t ** self.scaler + alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod + return alpha_cumprod.clamp(0.0001, 0.9999) + else: + return self.cached_steps[t.mul(len(self.cached_steps) - 1).long()] + + def diffuse(self, x, t, noise=None): # t -> [0, 1] + if noise is None: + noise = torch.randn_like(x) + alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) + return alpha_cumprod.sqrt() * x + (1 - alpha_cumprod).sqrt() * noise, noise + + def undiffuse(self, x, t, t_prev, noise, sampler=None): + if sampler is None: + sampler = DDPMSampler(self) + return sampler(x, t, t_prev, noise) + + def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, + unconditional_inputs=None, sampler='ddpm', half=False): + r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) + if isinstance(sampler, str): + if sampler in sampler_dict: + sampler = sampler_dict[sampler](self) + else: + raise ValueError( + f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") + elif issubclass(sampler, SimpleSampler): + sampler = sampler(self) + else: + raise ValueError("Sampler should be either a string or a SimpleSampler object.") + preds = [] + x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() + if half: + r_range = r_range.half() + x = x.half() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = {k: torch.cat([v, v_u]) for (k, v), (k_u, v_u) in + zip(model_inputs.items(), unconditional_inputs.items())} + for i in range(0, timesteps): + if mask is not None and x_init is not None: + x_renoised, _ = self.diffuse(x_init, r_range[i]) + x = x * mask + x_renoised * (1 - mask) + + if cfg is not None: + pred_noise, pred_noise_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), + **model_inputs).chunk(2) + pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) + else: + pred_noise = model(x, r_range[i], **model_inputs) + + x = self.undiffuse(x, r_range[i], r_range[i + 1], pred_noise, sampler=sampler) + preds.append(x) + return preds + + def p2_weight(self, t, k=1.0, gamma=1.0): + alpha_cumprod = self._alpha_cumprod(t) + return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma \ No newline at end of file diff --git a/scripts/the_ulitmate_test.py b/scripts/the_ulitmate_test.py new file mode 100644 index 000000000000..90394e59502d --- /dev/null +++ b/scripts/the_ulitmate_test.py @@ -0,0 +1,41 @@ +# import torch +# from diffuzz import Diffuzz +# from diffusers import DDPMWuerstchenScheduler + +# torch.manual_seed(42) +# scheduler = DDPMWuerstchenScheduler() +# scheduler.set_timesteps({0.0: 30}) +# diffuzz = Diffuzz() + +# shape = (1, 16, 24, 24) +# x = torch.randn(shape) +# noise = torch.randn(shape) +# t = torch.rand(1) +# t_prev = t - 0.1 + +# output_diffuzz = diffuzz.undiffuse(x, t, t_prev, noise) +# output_scheduler = scheduler.step(noise, timestep=t, prev_t=t_prev, sample=x).prediction +# # scheduler.step(noise, timestep=t, sample=x) + +# print(output_diffuzz.mean()) +# print(output_scheduler.mean()) +# print(output_diffuzz.shape) +# print(output_scheduler.shape) + +from transformers import AutoTokenizer, CLIPTextModel + +device = "cuda" + +def embed_clip(caption, negative_caption="", batch_size=4, device="cuda"): + clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) + clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state + return clip_text_embeddings + +clip_model = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(device).eval().requires_grad_(False) +clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +caption = "An armchair in the shape of an avocado" + +emb = embed_clip(caption) + +print(emb) \ No newline at end of file diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 9a920a8c3cb9..4c2b54589497 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -61,28 +61,45 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # for i, image in enumerate(images): # image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) +torch.manual_seed(42) + +prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) + +# from diffusers import DDPMScheduler +# noise_scheduler = DDPMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") +# prior_pipeline.scheduler = noise_scheduler + +generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype) -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenPriorPipeline", torch_dtype=dtype) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") + +def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_size=4, device="cuda"): + clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) + clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state + + clip_tokens_uncond = clip_tokenizer([negative_caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) + clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state + return clip_text_embeddings, clip_text_embeddings_uncond + # generator_pipeline.vqgan.to(torch.float16) -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +# clip_model_G = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to("cpu") +# clip_tokenizer_G = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +clip_model_H = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +clip_tokenizer_H = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" # caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -clip_tokens = tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") -clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) -clip_tokens_uncond = tokenizer([negative_prompt] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") -clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) +caption = "An astronaut riding a horse" +# clip_text_embeddings, clip_text_embeddings_uncond = embed_clip(clip_model_G, clip_tokenizer_G, caption, negative_prompt, batch_size, "cpu") +# embeds = torch.cat([clip_text_embeddings, clip_text_embeddings_uncond]).to(device).to(dtype) prior_output = prior_pipeline(caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt) -generator_output = generator_pipeline(prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images +clip_text_embeddings, clip_text_embeddings_uncond = embed_clip(clip_model_H, clip_tokenizer_H, caption, negative_prompt, batch_size, "cpu") +generator_output = generator_pipeline(prior_output.image_embeds, clip_text_embeddings.to(device).to(dtype), guidance_scale=0.0, output_type="np").images images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) for i, image in enumerate(images): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 6eec8622a83d..84de6275e0da 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -193,6 +193,7 @@ def __call__( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, effnet=effnet, + # clip=torch.randn(latents.size(0)*2 if do_classifier_free_guidance else latents.size(0), 77, 1024).to(device).to(dtype) clip=torch.cat([text_encoder_hidden_states] * 2) if do_classifier_free_guidance else text_encoder_hidden_states, ) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 118093ec8890..affe5af3b055 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -129,20 +129,20 @@ def _encode_prompt( max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids + ).to(device) + # text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + # untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + # removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because CLIP can only handle sequences up to" + # f" {self.tokenizer.model_max_length} tokens: {removed_text}" + # ) + # text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) + text_encoder_output = self.text_encoder(**text_inputs) text_encoder_hidden_states = text_encoder_output.last_hidden_state @@ -174,8 +174,8 @@ def _encode_prompt( max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", - ) - negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + ).to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(**uncond_input) uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state @@ -195,46 +195,6 @@ def _encode_prompt( return text_encoder_hidden_states - # @torch.no_grad() - # def inference_loop( - # self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - # ): - # print(steps) - # print(steps[:-1]) - # for i, t in enumerate(self.progress_bar(steps[:-1])): - # # print(torch.cat([latents] * 2).shape, latents.dtype) - # # print(t.expand(latents.size(0) * 2).shape, t.dtype) - # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - # predicted_image_embedding = self.prior( - # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, - # c=text_encoder_hidden_states, - # ) - - # # print(t.expand(latents.size(0) * 2)) - # # print(i, predicted_image_embedding[0, 0, :4, :4]) - # # print(text_encoder_hidden_states[0, 4, :4]) - - # if do_classifier_free_guidance: - # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - # predicted_image_embedding_text - predicted_image_embedding_uncond - # ) - # # print(t) - - # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - # # dtype=t.dtype - # # ) - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - # return latents @torch.no_grad() def __call__( @@ -274,6 +234,7 @@ def __call__( text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + dtype = text_encoder_hidden_states.dtype latent_height = 128 * (height // 128) // (1024 // 24) latent_width = 128 * (width // 128) // (1024 // 24) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 5c5a217612d0..e7ab067c71fd 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -171,6 +171,7 @@ def step( timestep: int, sample: torch.FloatTensor, generator=None, + # prev_t=None, return_dict: bool = True, ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]: """ @@ -194,16 +195,21 @@ def step( dtype = model_output.dtype device = model_output.device t = timestep + prev_t = self.previous_timestep(t) alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) alpha = (alpha_cumprod / alpha_cumprod_prev) + # print(f"scheduler: {alpha}") mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) + # print(f"scheduler: {mu.mean()}") + # torch.manual_seed(0) std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) # std_noise = torch.randn_like(mu) std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * std_noise + # print(f"scheduler: {std.mean()}") pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) if not return_dict: @@ -247,4 +253,7 @@ def previous_timestep(self, timestep): # print(self.timesteps[index + 1]) prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0]) # print(prev_t.shape, prev_t) + # print(timestep) + # print(prev_t) + # print("======== ===================================================================") return prev_t From 3ae3ea462d5c2164433d62c38090f5d4f548c17c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 4 Aug 2023 21:15:36 +0200 Subject: [PATCH 058/181] add attention_masks to text_encoder --- scripts/wuerstchen_pipeline_test.py | 70 ++++++++++++------- .../wuerstchen/pipeline_wuerstchen_prior.py | 11 ++- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 9a920a8c3cb9..8c7ee8de4ac3 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -1,11 +1,11 @@ import os import numpy as np import torch -import torchvision import transformers from PIL import Image from transformers import AutoTokenizer, CLIPTextModel from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + transformers.utils.logging.set_verbosity_error() @@ -20,19 +20,24 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: return pil_images -effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(768, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - torchvision.transforms.CenterCrop(768), - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) -]) - -transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(1024), - torchvision.transforms.RandomCrop(1024), -]) + +# effnet_preprocess = torchvision.transforms.Compose( +# [ +# torchvision.transforms.Resize( +# 768, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True +# ), +# torchvision.transforms.CenterCrop(768), +# torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), +# ] +# ) + +# transforms = torchvision.transforms.Compose( +# [ +# torchvision.transforms.ToTensor(), +# torchvision.transforms.Resize(1024), +# torchvision.transforms.RandomCrop(1024), +# ] +# ) device = "cuda" dtype = torch.float16 batch_size = 4 @@ -62,8 +67,10 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenPriorPipeline", torch_dtype=dtype) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) +prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) +generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( + "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype +) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") # generator_pipeline.vqgan.to(torch.float16) @@ -76,21 +83,35 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" caption = "An armchair in the shape of an avocado" -clip_tokens = tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") +clip_tokens = tokenizer( + [caption] * batch_size, + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", +) clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) -clip_tokens_uncond = tokenizer([negative_prompt] * batch_size, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") -clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) - -prior_output = prior_pipeline(caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt) -generator_output = generator_pipeline(prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images +# clip_tokens_uncond = tokenizer( +# [negative_prompt] * batch_size, +# truncation=True, +# padding="max_length", +# max_length=tokenizer.model_max_length, +# return_tensors="pt", +# ) +# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) + +prior_output = prior_pipeline( + caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt +) +generator_output = generator_pipeline( + prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np" +).images images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) for i, image in enumerate(images): image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - - # caption = input("Prompt please: ") # while caption != "q": # prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) @@ -102,4 +123,3 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) # caption = input("Prompt please: ") - diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 118093ec8890..35a6e13bdcb6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -131,6 +131,7 @@ def _encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids @@ -141,8 +142,12 @@ def _encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) + text_encoder_output = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) text_encoder_hidden_states = text_encoder_output.last_hidden_state @@ -175,7 +180,9 @@ def _encode_prompt( truncation=True, return_tensors="pt", ) - negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state From b3b2b601cc74e15532ffa5b8345754c2419ec626 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 4 Aug 2023 21:18:01 +0200 Subject: [PATCH 059/181] make fix-copies --- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 35a6e13bdcb6..5d2a753254de 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -144,10 +144,7 @@ def _encode_prompt( text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] attention_mask = attention_mask[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask.to(device), - ) + text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) text_encoder_hidden_states = text_encoder_output.last_hidden_state From 86630378c119aa8b54e5e4b1a4d1e19a0380e871 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 4 Aug 2023 21:37:43 +0200 Subject: [PATCH 060/181] add back clip --- .../wuerstchen/pipeline_wuerstchen.py | 4 +- .../wuerstchen/pipeline_wuerstchen_prior.py | 43 ------------------- 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 00c4927ee3df..2929569cc308 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -197,7 +197,9 @@ def __call__( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, effnet=effnet, - clip=None, # torch.cat([text_encoder_hidden_states] * 2) if do_classifier_free_guidance else text_encoder_hidden_states, + clip=torch.cat([text_encoder_hidden_states] * 2) + if do_classifier_free_guidance + else text_encoder_hidden_states, ) if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 5d2a753254de..f811907ec149 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -145,9 +145,7 @@ def _encode_prompt( attention_mask = attention_mask[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) - text_encoder_hidden_states = text_encoder_output.last_hidden_state - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: @@ -199,47 +197,6 @@ def _encode_prompt( return text_encoder_hidden_states - # @torch.no_grad() - # def inference_loop( - # self, latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - # ): - # print(steps) - # print(steps[:-1]) - # for i, t in enumerate(self.progress_bar(steps[:-1])): - # # print(torch.cat([latents] * 2).shape, latents.dtype) - # # print(t.expand(latents.size(0) * 2).shape, t.dtype) - # # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) - # predicted_image_embedding = self.prior( - # torch.cat([latents] * 2) if do_classifier_free_guidance else latents, - # r=t.expand(latents.size(0) * 2) if do_classifier_free_guidance else t, - # c=text_encoder_hidden_states, - # ) - - # # print(t.expand(latents.size(0) * 2)) - # # print(i, predicted_image_embedding[0, 0, :4, :4]) - # # print(text_encoder_hidden_states[0, 4, :4]) - - # if do_classifier_free_guidance: - # predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) - # predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( - # predicted_image_embedding_text - predicted_image_embedding_uncond - # ) - # # print(t) - - # # latents = self.diffuzz.undiffuse(latents, t[None], steps[i + 1][None], predicted_image_embedding).to( - # # dtype=t.dtype - # # ) - # timestep = (t * 999).cpu().int() - # # print(timestep) - # latents = self.scheduler.step( - # predicted_image_embedding, - # timestep=timestep - 1, - # sample=latents, - # generator=generator, - # ).prev_sample - - # return latents - @torch.no_grad() def __call__( self, From 9ec3f0133f85c93b0f69a601c1613046994a9800 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 5 Aug 2023 19:32:01 +0200 Subject: [PATCH 061/181] add text_encoder --- .../wuerstchen/pipeline_wuerstchen.py | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 2929569cc308..cb07790094f2 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -81,6 +81,8 @@ class WuerstchenGeneratorPipeline(DiffusionPipeline): def __init__( self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, generator: DiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, @@ -89,6 +91,8 @@ def __init__( super().__init__() self.multiple = 128 self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, generator=generator, scheduler=scheduler, vqgan=vqgan, @@ -108,6 +112,90 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): return latents + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) + text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) + + return text_encoder_hidden_states + def check_inputs( self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device ): @@ -145,7 +233,8 @@ def encode_image(self, image): def __call__( self, predicted_image_embeddings: torch.Tensor, - text_encoder_hidden_states: torch.Tensor = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, inference_steps: dict = None, guidance_scale: float = 3.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -160,6 +249,21 @@ def __call__( if inference_steps is None: inference_steps = default_inference_steps_b + if negative_prompt is None: + negative_prompt = "" + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + text_encoder_hidden_states = self._encode_prompt( + prompt, device, predicted_image_embeddings.size(0), do_classifier_free_guidance, negative_prompt + ) + predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs( predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device ) From 83f87efa63b07b0d4dcae487691a086d990f4736 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 5 Aug 2023 19:34:08 +0200 Subject: [PATCH 062/181] gen_text_encoder and tokenizer --- scripts/convert_wuerstchen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 1b75d07a186f..6bfb1fd9663f 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -39,6 +39,8 @@ efficient_net.load_state_dict(state_dict["effnet_state_dict"]) # Generator +gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") generator = DiffNeXt() generator.load_state_dict(state_dict["state_dict"]) @@ -61,6 +63,8 @@ prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") generator_pipeline = WuerstchenGeneratorPipeline( + text_encoder=gen_text_encoder, + tokenizer=gen_tokenizer, vqgan=vqmodel, generator=generator, efficient_net=efficient_net, From 7a3639d1b63654345a74c21f83e756faa8370a1c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 5 Aug 2023 19:59:07 +0200 Subject: [PATCH 063/181] fix import --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index cb07790094f2..9f2f261ebb3a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -17,6 +17,7 @@ import numpy as np import torch +from transformers import CLIPTextModel, CLIPTokenizer from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler From 3791b94544bb7b76f87a905e7ebeff4eded25f79 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 12:31:35 +0200 Subject: [PATCH 064/181] updated pipeline test --- scripts/wuerstchen_pipeline_test.py | 37 ++++++++++++------- .../wuerstchen/pipeline_wuerstchen.py | 4 +- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 8c7ee8de4ac3..66d4bfb3c115 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -74,23 +74,25 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") # generator_pipeline.vqgan.to(torch.float16) -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" -# caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" -# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -clip_tokens = tokenizer( - [caption] * batch_size, - truncation=True, - padding="max_length", - max_length=tokenizer.model_max_length, - return_tensors="pt", +caption = ( + "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" ) -clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) +# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" +# caption = "An armchair in the shape of an avocado" +# clip_tokens = tokenizer( +# [caption] * batch_size, +# truncation=True, +# padding="max_length", +# max_length=tokenizer.model_max_length, +# return_tensors="pt", +# ) +# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) # clip_tokens_uncond = tokenizer( # [negative_prompt] * batch_size, # truncation=True, @@ -101,10 +103,17 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) prior_output = prior_pipeline( - caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt + caption, + guidance_scale=8.0, + num_images_per_prompt=batch_size, + negative_prompt=negative_prompt, ) generator_output = generator_pipeline( - prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np" + predicted_image_embeddings=prior_output.image_embeds, + prompt=caption, + negative_prompt=negative_prompt, + guidance_scale=8.0, + output_type="np", ).images images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9f2f261ebb3a..75650ddc3cd9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -302,9 +302,7 @@ def __call__( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, effnet=effnet, - clip=torch.cat([text_encoder_hidden_states] * 2) - if do_classifier_free_guidance - else text_encoder_hidden_states, + clip=text_encoder_hidden_states, ) if do_classifier_free_guidance: From f91b12e172720e0ad9b3283cf67270ba1f147549 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 12:52:24 +0200 Subject: [PATCH 065/181] undo changes to pipeline test --- scripts/convert_wuerstchen.py | 4 +- scripts/wuerstchen_pipeline_test.py | 56 ++++++++----------- .../wuerstchen/pipeline_wuerstchen.py | 1 + .../wuerstchen/pipeline_wuerstchen_prior.py | 3 +- 4 files changed, 29 insertions(+), 35 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 6bfb1fd9663f..78ae402ea6b5 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -7,10 +7,10 @@ from diffusers import ( DDPMWuerstchenScheduler, VQModelPaella, - WuerstchenPriorPipeline, WuerstchenGeneratorPipeline, + WuerstchenPriorPipeline, ) -from diffusers.pipelines.wuerstchen import Prior, DiffNeXt, EfficientNetEncoder +from diffusers.pipelines.wuerstchen import DiffNeXt, EfficientNetEncoder, Prior model_path = "models/" diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 3e916a234b3a..5d5f495ae965 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -1,10 +1,12 @@ import os + import numpy as np import torch import transformers from PIL import Image -from transformers import AutoTokenizer, CLIPTextModel -from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + +from diffusers import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline + transformers.utils.logging.set_verbosity_error() @@ -68,32 +70,15 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: torch.manual_seed(42) -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) - -# from diffusers import DDPMScheduler -# noise_scheduler = DDPMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") -# prior_pipeline.scheduler = noise_scheduler - -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype) - prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype ) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") - -def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_size=4, device="cuda"): - clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - - clip_tokens_uncond = clip_tokenizer([negative_caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state - return clip_text_embeddings, clip_text_embeddings_uncond - # generator_pipeline.vqgan.to(torch.float16) -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" @@ -102,15 +87,15 @@ def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_s "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" ) # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -clip_tokens = tokenizer( - [caption] * batch_size, - truncation=True, - padding="max_length", - max_length=tokenizer.model_max_length, - return_tensors="pt", -) -clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) +# caption = "An armchair in the shape of an avocado" +# clip_tokens = tokenizer( +# [caption] * batch_size, +# truncation=True, +# padding="max_length", +# max_length=tokenizer.model_max_length, +# return_tensors="pt", +# ) +# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) # clip_tokens_uncond = tokenizer( # [negative_prompt] * batch_size, # truncation=True, @@ -121,10 +106,17 @@ def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_s # clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) prior_output = prior_pipeline( - caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt + caption, + guidance_scale=8.0, + num_images_per_prompt=batch_size, + negative_prompt=negative_prompt, ) generator_output = generator_pipeline( - prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np" + predicted_image_embeddings=prior_output.image_embeds, + prompt=caption, + negative_prompt=negative_prompt, + guidance_scale=8.0, + output_type="np", ).images images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 75650ddc3cd9..0c730ffbf69e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -25,6 +25,7 @@ from ..pipeline_utils import DiffusionPipeline from .modules import DiffNeXt, EfficientNetEncoder + # from .diffuzz import Diffuzz diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 453bf9d2e8d9..bbda54e61cce 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -24,6 +24,7 @@ from ..pipeline_utils import DiffusionPipeline from .prior import Prior + # from .diffuzz import Diffuzz @@ -235,7 +236,7 @@ def __call__( text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - + dtype = text_encoder_hidden_states.dtype latent_height = 128 * (height // 128) // (1024 // 24) latent_width = 128 * (width // 128) // (1024 // 24) From 5b518a217431a5eee3c471b813174d2660cb2b62 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Sun, 6 Aug 2023 12:58:01 +0200 Subject: [PATCH 066/181] nip --- scripts/wuerstchen_pipeline_test.py | 47 +++++++++++------------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 3e916a234b3a..b2b5d6f6d549 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -68,14 +68,6 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: torch.manual_seed(42) -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) - -# from diffusers import DDPMScheduler -# noise_scheduler = DDPMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") -# prior_pipeline.scheduler = noise_scheduler - -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype) - prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype @@ -83,34 +75,31 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") -def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_size=4, device="cuda"): - clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state +# def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_size=4, device="cuda"): +# clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) +# clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - clip_tokens_uncond = clip_tokenizer([negative_caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state - return clip_text_embeddings, clip_text_embeddings_uncond +# clip_tokens_uncond = clip_tokenizer([negative_caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) +# clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state +# return clip_text_embeddings, clip_text_embeddings_uncond # generator_pipeline.vqgan.to(torch.float16) -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") +# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" -caption = ( - "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" -) +caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -clip_tokens = tokenizer( - [caption] * batch_size, - truncation=True, - padding="max_length", - max_length=tokenizer.model_max_length, - return_tensors="pt", -) -clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) +# clip_tokens = tokenizer( +# [caption] * batch_size, +# truncation=True, +# padding="max_length", +# max_length=tokenizer.model_max_length, +# return_tensors="pt", +# ) +# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) # clip_tokens_uncond = tokenizer( # [negative_prompt] * batch_size, # truncation=True, @@ -124,7 +113,7 @@ def embed_clip(clip_model, clip_tokenizer, caption, negative_caption="", batch_s caption, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt ) generator_output = generator_pipeline( - prior_output.image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np" + prior_output.image_embeds, caption, negative_prompt=negative_prompt, guidance_scale=0.0, output_type="np" ).images images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) From 55ba4db8b7f240bb4c35455f8c9ed9ba7f7a367c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:22:09 +0200 Subject: [PATCH 067/181] fix typo --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index bbda54e61cce..0cb2c0ed3957 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -134,7 +134,7 @@ def _encode_prompt( text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask - # untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) From 48034a001cc112c5a55295fd07d3dfec7446b8f0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:23:36 +0200 Subject: [PATCH 068/181] fix output name --- .../wuerstchen/pipeline_wuerstchen.py | 2 +- .../wuerstchen/pipeline_wuerstchen_prior.py | 2 +- .../schedulers/scheduling_ddpm_wuerstchen.py | 28 ++++--------------- 3 files changed, 7 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 0c730ffbf69e..48c37808d2a4 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -315,7 +315,7 @@ def __call__( timestep=ratio, sample=latents, generator=generator, - ).prediction + ).prev_sample images = self.vqgan.decode(latents).sample.clamp(0, 1) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 0cb2c0ed3957..3b0d54da16a3 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -277,7 +277,7 @@ def __call__( timestep=ratio, sample=latents, generator=generator, - ).prediction + ).prev_sample # t_start = 1.0 # for t_end, steps in inference_steps.items(): diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 38ec8e37c0d4..36b2b0b767d0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -35,12 +35,9 @@ class DDPMWuerstchenSchedulerOutput(BaseOutput): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. """ - prediction: torch.FloatTensor + prev_sample: torch.FloatTensor def betas_for_alpha_bar( @@ -165,7 +162,6 @@ def set_timesteps( timesteps = torch.cat([timesteps, steps[1:]]) self.timesteps = timesteps - # print(f"Timesteps: {self.timesteps}, Timesteps Shape: {timesteps.shape}") def step( self, @@ -202,22 +198,18 @@ def step( alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) - alpha = (alpha_cumprod / alpha_cumprod_prev) - # print(f"scheduler: {alpha}") + alpha = alpha_cumprod / alpha_cumprod_prev mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) - # print(f"scheduler: {mu.mean()}") - # torch.manual_seed(0) + std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) - # std_noise = torch.randn_like(mu) - std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * std_noise - # print(f"scheduler: {std.mean()}") + std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) if not return_dict: return (pred.to(dtype),) - return DDPMWuerstchenSchedulerOutput(prediction=pred.to(dtype)) + return DDPMWuerstchenSchedulerOutput(prev_sample=pred.to(dtype)) def add_noise( self, @@ -246,16 +238,6 @@ def __len__(self): return self.config.num_train_timesteps def previous_timestep(self, timestep): - # print(f"Timestep Shape: {timestep.shape}") - # print(timestep) - # print((self.timesteps == timestep[0]).nonzero()) - # index = (self.timesteps == timestep[0]).nonzero().item() index = (self.timesteps - timestep[0]).abs().argmin().item() - # print(f"Found index at {index}") - # print(self.timesteps[index + 1]) prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0]) - # print(prev_t.shape, prev_t) - # print(timestep) - # print(prev_t) - # print("======== ===================================================================") return prev_t From be2529ad746209d6258057c8e7ef41a756503317 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:29:48 +0200 Subject: [PATCH 069/181] set guidance_scale=0 and remove diffuze --- scripts/wuerstchen_pipeline_test.py | 6 +- src/diffusers/pipelines/wuerstchen/diffuzz.py | 122 ------------------ .../wuerstchen/pipeline_wuerstchen.py | 2 - .../wuerstchen/pipeline_wuerstchen_prior.py | 1 - 4 files changed, 4 insertions(+), 127 deletions(-) delete mode 100644 src/diffusers/pipelines/wuerstchen/diffuzz.py diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index ea0f71c9f5b8..b2d2074e00f2 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -83,7 +83,9 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" -caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" +caption = ( + "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" +) # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" # caption = "An armchair in the shape of an avocado" # clip_tokens = tokenizer( @@ -113,7 +115,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: predicted_image_embeddings=prior_output.image_embeds, prompt=caption, negative_prompt=negative_prompt, - guidance_scale=8.0, + guidance_scale=0.0, output_type="np", ).images images = numpy_to_pil(generator_output) diff --git a/src/diffusers/pipelines/wuerstchen/diffuzz.py b/src/diffusers/pipelines/wuerstchen/diffuzz.py deleted file mode 100644 index 79e1760c3fc6..000000000000 --- a/src/diffusers/pipelines/wuerstchen/diffuzz.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch - - -# Samplers -------------------------------------------------------------------- -class SimpleSampler(): - def __init__(self, diffuzz): - self.current_step = -1 - self.diffuzz = diffuzz - - def __call__(self, *args, **kwargs): - self.current_step += 1 - return self.step(*args, **kwargs) - - def init_x(self, shape): - return torch.randn(*shape, device=self.diffuzz.device) - - def step(self, x, t, t_prev, noise): - raise NotImplementedError("You should override the 'apply' function.") - - -class DDPMSampler(SimpleSampler): - def step(self, x, t, t_prev, noise): - alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - alpha = (alpha_cumprod / alpha_cumprod_prev) - - mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) - std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu) - return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - - -class DDIMSampler(SimpleSampler): - def step(self, x, t, t_prev, noise): - alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - - x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() - dp_xt = (1 - alpha_cumprod_prev).sqrt() - return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise - - -sampler_dict = { - 'ddpm': DDPMSampler, - 'ddim': DDIMSampler, -} - - -# Custom simplified foward/backward diffusion (cosine schedule) -class Diffuzz(): - def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1): - self.device = device - self.s = torch.tensor([s]).to(device) - self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 - self.scaler = scaler - self.cached_steps = None - if cache_steps is not None: - self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) - - def _alpha_cumprod(self, t): - if self.cached_steps is None: - if self.scaler > 1: - t = 1 - (1 - t) ** self.scaler - elif self.scaler < 1: - t = t ** self.scaler - alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod - return alpha_cumprod.clamp(0.0001, 0.9999) - else: - return self.cached_steps[t.mul(len(self.cached_steps) - 1).long()] - - def diffuse(self, x, t, noise=None): # t -> [0, 1] - if noise is None: - noise = torch.randn_like(x) - alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - return alpha_cumprod.sqrt() * x + (1 - alpha_cumprod).sqrt() * noise, noise - - def undiffuse(self, x, t, t_prev, noise, sampler=None): - if sampler is None: - sampler = DDPMSampler(self) - return sampler(x, t, t_prev, noise) - - def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, - unconditional_inputs=None, sampler='ddpm', half=False): - r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) - if isinstance(sampler, str): - if sampler in sampler_dict: - sampler = sampler_dict[sampler](self) - else: - raise ValueError( - f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") - elif issubclass(sampler, SimpleSampler): - sampler = sampler(self) - else: - raise ValueError("Sampler should be either a string or a SimpleSampler object.") - preds = [] - x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() - if half: - r_range = r_range.half() - x = x.half() - if cfg is not None: - if unconditional_inputs is None: - unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} - model_inputs = {k: torch.cat([v, v_u]) for (k, v), (k_u, v_u) in - zip(model_inputs.items(), unconditional_inputs.items())} - for i in range(0, timesteps): - if mask is not None and x_init is not None: - x_renoised, _ = self.diffuse(x_init, r_range[i]) - x = x * mask + x_renoised * (1 - mask) - - if cfg is not None: - pred_noise, pred_noise_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), - **model_inputs).chunk(2) - pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) - else: - pred_noise = model(x, r_range[i], **model_inputs) - - x = self.undiffuse(x, r_range[i], r_range[i + 1], pred_noise, sampler=sampler) - preds.append(x) - return preds - - def p2_weight(self, t, k=1.0, gamma=1.0): - alpha_cumprod = self._alpha_cumprod(t) - return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma \ No newline at end of file diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 48c37808d2a4..4b08fe1a0df8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -100,8 +100,6 @@ def __init__( vqgan=vqgan, efficient_net=efficient_net, ) - # self.diffuzz = Diffuzz(device="cuda") - self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 3b0d54da16a3..de24ec931932 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -101,7 +101,6 @@ def __init__( prior=prior, scheduler=scheduler, ) - # self.diffuzz = Diffuzz(device="cuda") self.register_to_config() def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): From be1aa9656323f0e66406ce5f7629598b359df3d0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:32:01 +0200 Subject: [PATCH 070/181] fix doc strings --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 6 +++++- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 4b08fe1a0df8..d650f5d452c9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -71,13 +71,17 @@ class WuerstchenGeneratorPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. generator ([`DiffNeXt`]): The DiffNeXt unet generator. vqgan ([`VQModelPaella`]): The VQGAN model. efficient_net ([`EfficientNetEncoder`]): The EfficientNet encoder. - scheduler ([`DDPMScheduler`]): + scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. """ diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index de24ec931932..ae66d467fa79 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -82,7 +82,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - scheduler ([`DDPMScheduler`]): + scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. """ From a1114e3ef642b2e7b96f30a2ef165fa2a43ba7b7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:36:46 +0200 Subject: [PATCH 071/181] make style --- scripts/diffuzz.py | 126 --------------- scripts/vqgan.py | 144 ------------------ src/diffusers/__init__.py | 4 +- src/diffusers/models/unet_2d_blocks.py | 12 -- .../pipelines/wuerstchen/__init__.py | 4 +- .../wuerstchen/pipeline_wuerstchen.py | 10 +- .../wuerstchen/pipeline_wuerstchen_prior.py | 8 +- src/diffusers/schedulers/__init__.py | 2 +- .../schedulers/scheduling_ddpm_wuerstchen.py | 3 +- 9 files changed, 19 insertions(+), 294 deletions(-) delete mode 100644 scripts/diffuzz.py delete mode 100644 scripts/vqgan.py diff --git a/scripts/diffuzz.py b/scripts/diffuzz.py deleted file mode 100644 index 9f80a406b952..000000000000 --- a/scripts/diffuzz.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch - - -# Samplers -------------------------------------------------------------------- -class SimpleSampler(): - def __init__(self, diffuzz): - self.current_step = -1 - self.diffuzz = diffuzz - - def __call__(self, *args, **kwargs): - self.current_step += 1 - return self.step(*args, **kwargs) - - def init_x(self, shape): - return torch.randn(*shape, device=self.diffuzz.device) - - def step(self, x, t, t_prev, noise): - raise NotImplementedError("You should override the 'apply' function.") - - -class DDPMSampler(SimpleSampler): - def step(self, x, t, t_prev, noise): - alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - alpha = (alpha_cumprod / alpha_cumprod_prev) - print(f"diffuzz: {alpha}") - - mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) - print(f"diffuzz: {mu.mean()}") - torch.manual_seed(0) - std = ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu) - print(f"diffuzz: {std.mean()}") - return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - - -class DDIMSampler(SimpleSampler): - def step(self, x, t, t_prev, noise): - alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) - - x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() - dp_xt = (1 - alpha_cumprod_prev).sqrt() - return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise - - -sampler_dict = { - 'ddpm': DDPMSampler, - 'ddim': DDIMSampler, -} - - -# Custom simplified foward/backward diffusion (cosine schedule) -class Diffuzz(): - def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1): - self.device = device - self.s = torch.tensor([s]).to(device) - self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 - self.scaler = scaler - self.cached_steps = None - if cache_steps is not None: - self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) - - def _alpha_cumprod(self, t): - if self.cached_steps is None: - if self.scaler > 1: - t = 1 - (1 - t) ** self.scaler - elif self.scaler < 1: - t = t ** self.scaler - alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod - return alpha_cumprod.clamp(0.0001, 0.9999) - else: - return self.cached_steps[t.mul(len(self.cached_steps) - 1).long()] - - def diffuse(self, x, t, noise=None): # t -> [0, 1] - if noise is None: - noise = torch.randn_like(x) - alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) - return alpha_cumprod.sqrt() * x + (1 - alpha_cumprod).sqrt() * noise, noise - - def undiffuse(self, x, t, t_prev, noise, sampler=None): - if sampler is None: - sampler = DDPMSampler(self) - return sampler(x, t, t_prev, noise) - - def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, - unconditional_inputs=None, sampler='ddpm', half=False): - r_range = torch.linspace(t_start, t_end, timesteps + 1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) - if isinstance(sampler, str): - if sampler in sampler_dict: - sampler = sampler_dict[sampler](self) - else: - raise ValueError( - f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") - elif issubclass(sampler, SimpleSampler): - sampler = sampler(self) - else: - raise ValueError("Sampler should be either a string or a SimpleSampler object.") - preds = [] - x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() - if half: - r_range = r_range.half() - x = x.half() - if cfg is not None: - if unconditional_inputs is None: - unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} - model_inputs = {k: torch.cat([v, v_u]) for (k, v), (k_u, v_u) in - zip(model_inputs.items(), unconditional_inputs.items())} - for i in range(0, timesteps): - if mask is not None and x_init is not None: - x_renoised, _ = self.diffuse(x_init, r_range[i]) - x = x * mask + x_renoised * (1 - mask) - - if cfg is not None: - pred_noise, pred_noise_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), - **model_inputs).chunk(2) - pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) - else: - pred_noise = model(x, r_range[i], **model_inputs) - - x = self.undiffuse(x, r_range[i], r_range[i + 1], pred_noise, sampler=sampler) - preds.append(x) - return preds - - def p2_weight(self, t, k=1.0, gamma=1.0): - alpha_cumprod = self._alpha_cumprod(t) - return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma \ No newline at end of file diff --git a/scripts/vqgan.py b/scripts/vqgan.py deleted file mode 100644 index 935023e1ff6e..000000000000 --- a/scripts/vqgan.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -from torch import nn -import numpy as np -import math -from tqdm import tqdm -import time -from torchtools.nn import VectorQuantize - -class ResBlock(nn.Module): - def __init__(self, c, c_hidden): - super().__init__() - # depthwise/attention - self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.depthwise = nn.Sequential( - nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) - ) - - # channelwise - self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), - nn.GELU(), - nn.Linear(c_hidden, c), - ) - - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - - def _norm(self, x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - def forward(self, x): - mods = self.gammas - - x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] - - x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - - return x - -class VQModel(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764): # 1.0 - super().__init__() - self.c_latent = c_latent - self.scale_factor = scale_factor - c_levels = [c_hidden//(2**i) for i in reversed(range(levels))] - - # Encoder blocks - self.in_block = nn.Sequential( - nn.PixelUnshuffle(2), - nn.Conv2d(3*4, c_levels[0], kernel_size=1) - ) - down_blocks = [] - for i in range(levels): - if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i-1], c_levels[i], kernel_size=4, stride=2, padding=1)) - block = ResBlock(c_levels[i], c_levels[i]*4) - down_blocks.append(block) - down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - )) - self.down_blocks = nn.Sequential(*down_blocks) - self.down_blocks[0] - - self.codebook_size = codebook_size - self.vquantizer = VectorQuantize(c_latent, k=codebook_size) - - # Decoder blocks - up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) - )] - for i in range(levels): - for j in range(bottleneck_blocks if i == 0 else 1): - block = ResBlock(c_levels[levels-1-i], c_levels[levels-1-i]*4) - up_blocks.append(block) - if i < levels-1: - up_blocks.append(nn.ConvTranspose2d(c_levels[levels-1-i], c_levels[levels-2-i], kernel_size=4, stride=2, padding=1)) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3*4, kernel_size=1), - nn.PixelShuffle(2), - ) - - def encode(self, x): - x = self.in_block(x) - x = self.down_blocks(x) - qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 - - def decode(self, x): - x = x * self.scale_factor - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def decode_indices(self, x): - x = self.vquantizer.idx2vq(x, dim=1) - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def forward(self, x, quantize=False): - qe, x, _, vq_loss = self.encode(x, quantize) - x = self.decode(qe) - return x, vq_loss - -class Discriminator(nn.Module): - def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): - super().__init__() - d = max(depth - 3, 3) - layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), - nn.LeakyReLU(0.2), - ] - for i in range(depth - 1): - c_in = c_hidden // (2 ** max((d - i), 0)) - c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) - layers.append(nn.InstanceNorm2d(c_out)) - layers.append(nn.LeakyReLU(0.2)) - self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) - self.logits = nn.Sigmoid() - - def forward(self, x, cond=None): - x = self.encoder(x) - if cond is not None: - cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) - x = torch.cat([x, cond], dim=1) - x = self.shuffle(x) - x = self.logits(x) - return x \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0166c0b8c4ba..6ed7a9f5d16e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -41,7 +41,6 @@ AutoencoderTiny, ControlNetModel, ModelMixin, - VQModelPaella, MultiAdapter, PriorTransformer, T2IAdapter, @@ -52,6 +51,7 @@ UNet2DModel, UNet3DConditionModel, VQModel, + VQModelPaella, ) from .optimization import ( get_constant_schedule, @@ -87,8 +87,8 @@ DDIMParallelScheduler, DDIMScheduler, DDPMParallelScheduler, - DDPMWuerstchenScheduler, DDPMScheduler, + DDPMWuerstchenScheduler, DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 0152678663c9..68ad1e2a29ac 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -27,7 +27,6 @@ Downsample2D, FirDownsample2D, FirUpsample2D, - GlobalResponseResidualBlock, KDownsample2D, KUpsample2D, ResnetBlock2D, @@ -433,17 +432,6 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") -def get_paella_block(block_type, c_hidden, nhead, c_cond, c_r, kernel_size=3, c_skip=0, dropout=0, self_attn=True): - if block_type == "C": - return GlobalResponseResidualBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) - elif block_type == "A": - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) - elif block_type == "T": - return TimestepBlock(c_hidden, c_r) - else: - raise ValueError(f"'Block type {block_type} not supported.") - - class AutoencoderTinyBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int, act_fn: str): super().__init__() diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 4cf1d2bcda88..54a7dee3b533 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -3,6 +3,6 @@ if is_transformers_available() and is_torch_available(): from .modules import DiffNeXt, EfficientNetEncoder - from .prior import Prior from .pipeline_wuerstchen import WuerstchenGeneratorPipeline - from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + from .prior import Prior diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index d650f5d452c9..abe93dcfd513 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -37,8 +37,12 @@ >>> import torch >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline - >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained("kashif/wuerstchen-prior", torch_dtype=torch.float16).to("cuda") - >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain("kashif/wuerstchen-gen", torch_dtype=torch.float16).to("cuda") + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "kashif/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain( + ... "kashif/wuerstchen-gen", torch_dtype=torch.float16 + ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) @@ -65,7 +69,7 @@ class WuerstchenGeneratorPipelineOutput(BaseOutput): class WuerstchenGeneratorPipeline(DiffusionPipeline): """ - Pipeline for generating images from the Wuerstchen model. + Pipeline for generating images from the Wuerstchen model. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index ae66d467fa79..2d4ece9daeb6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -36,8 +36,12 @@ >>> import torch >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline - >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained("kashif/wuerstchen-prior", torch_dtype=torch.float16).to("cuda") - >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain("kashif/wuerstchen-gen", torch_dtype=torch.float16).to("cuda") + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "kashif/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain( + ... "kashif/wuerstchen-gen", torch_dtype=torch.float16 + ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index fda44eb032ab..84df4ffb84db 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -34,6 +34,7 @@ from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler + from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler @@ -53,7 +54,6 @@ from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler - from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 36b2b0b767d0..9da0f283251c 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -18,12 +18,11 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin @dataclass From 596c7f552b875c578c1fa4b2f9fdec9114446790 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Sun, 6 Aug 2023 16:52:27 +0200 Subject: [PATCH 072/181] nip --- .gitignore | 1 + scripts/convert_wuerstchen_prior.py | 37 +++++++++++++++++++ scripts/wuerstchen_pipeline_test.py | 10 +++-- .../wuerstchen/pipeline_wuerstchen_prior.py | 2 +- 4 files changed, 45 insertions(+), 5 deletions(-) create mode 100644 scripts/convert_wuerstchen_prior.py diff --git a/.gitignore b/.gitignore index 5b7f6b1ec1de..776df19820d3 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,7 @@ tags wandb scripts/models/ scripts/warp-diffusion/ +scripts/warp-diffusion-test/ scripts/samples/ scripts/text_encoding_colab.pt scripts/text_encoding.pt diff --git a/scripts/convert_wuerstchen_prior.py b/scripts/convert_wuerstchen_prior.py new file mode 100644 index 000000000000..db0ee9100b1f --- /dev/null +++ b/scripts/convert_wuerstchen_prior.py @@ -0,0 +1,37 @@ +import os + +import torch +from transformers import AutoTokenizer, CLIPTextModel + +from diffusers import ( + DDPMWuerstchenScheduler, + WuerstchenPriorPipeline, +) +from diffusers.pipelines.wuerstchen import Prior + + +model_path = "models/" +device = "cpu" + +# Clip Text encoder and tokenizer +text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +# Prior +state_dict = torch.load(os.path.join(model_path, "model_v2_stage_c.pt"), map_location=device) +prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) +prior_model.load_state_dict(state_dict["ema_state_dict"]) + +# scheduler +scheduler = DDPMWuerstchenScheduler() + +# Prior pipeline +prior_pipeline = WuerstchenPriorPipeline( + prior=prior_model, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, +) + +prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") + diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index ea0f71c9f5b8..97ca76d19cbb 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -42,7 +42,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # ) device = "cuda" dtype = torch.float16 -batch_size = 4 +batch_size = 2 # generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) # generator_pipeline = generator_pipeline.to("cuda") @@ -83,9 +83,9 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # negative_prompt = "low resolution, low detail, bad quality, blurry" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" # negative_prompt = "" -caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" +# caption = "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -# caption = "An armchair in the shape of an avocado" +caption = "An armchair in the shape of an avocado" # clip_tokens = tokenizer( # [caption] * batch_size, # truncation=True, @@ -105,6 +105,8 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: prior_output = prior_pipeline( caption, + height=1024, + width=1536, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt, @@ -113,7 +115,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: predicted_image_embeddings=prior_output.image_embeds, prompt=caption, negative_prompt=negative_prompt, - guidance_scale=8.0, + guidance_scale=0.0, output_type="np", ).images images = numpy_to_pil(generator_output) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index bbda54e61cce..0cb2c0ed3957 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -134,7 +134,7 @@ def _encode_prompt( text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask - # untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) From 4ae05f372a64ee3d3f94e9db9296c9dec1dd6849 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 6 Aug 2023 16:54:45 +0200 Subject: [PATCH 073/181] removed unused --- scripts/the_ulitmate_test.py | 41 ------------------------------------ 1 file changed, 41 deletions(-) delete mode 100644 scripts/the_ulitmate_test.py diff --git a/scripts/the_ulitmate_test.py b/scripts/the_ulitmate_test.py deleted file mode 100644 index 90394e59502d..000000000000 --- a/scripts/the_ulitmate_test.py +++ /dev/null @@ -1,41 +0,0 @@ -# import torch -# from diffuzz import Diffuzz -# from diffusers import DDPMWuerstchenScheduler - -# torch.manual_seed(42) -# scheduler = DDPMWuerstchenScheduler() -# scheduler.set_timesteps({0.0: 30}) -# diffuzz = Diffuzz() - -# shape = (1, 16, 24, 24) -# x = torch.randn(shape) -# noise = torch.randn(shape) -# t = torch.rand(1) -# t_prev = t - 0.1 - -# output_diffuzz = diffuzz.undiffuse(x, t, t_prev, noise) -# output_scheduler = scheduler.step(noise, timestep=t, prev_t=t_prev, sample=x).prediction -# # scheduler.step(noise, timestep=t, sample=x) - -# print(output_diffuzz.mean()) -# print(output_scheduler.mean()) -# print(output_diffuzz.shape) -# print(output_scheduler.shape) - -from transformers import AutoTokenizer, CLIPTextModel - -device = "cuda" - -def embed_clip(caption, negative_caption="", batch_size=4, device="cuda"): - clip_tokens = clip_tokenizer([caption] * batch_size, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - return clip_text_embeddings - -clip_model = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(device).eval().requires_grad_(False) -clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - -caption = "An armchair in the shape of an avocado" - -emb = embed_clip(caption) - -print(emb) \ No newline at end of file From 71b7fa3b04ff552b7130b58f4f9da67323891e99 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 10:45:24 +0200 Subject: [PATCH 074/181] initial docs --- docs/source/en/api/pipelines/wuerstchen.mdx | 26 ++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.mdx b/docs/source/en/api/pipelines/wuerstchen.mdx index 23b5cc5f9f5c..58fcbc55c7ea 100644 --- a/docs/source/en/api/pipelines/wuerstchen.mdx +++ b/docs/source/en/api/pipelines/wuerstchen.mdx @@ -1 +1,25 @@ -# Würstchen \ No newline at end of file +# Würstchen + +[Wuerstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville. + +The abstract from the paper is: + +*We introduce Wuerstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.* + +The original codebase can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). + +## VQDiffusionPipeline +[[autodoc]] VQDiffusionPipeline + - all + - __call__ + +## ImagePipelineOutput +[[autodoc]] pipelines.ImagePipelineOutput + +## WuerstchenGeneratorPipeline +[[autodoc]] WuerstchenGeneratorPipeline + - all + - __call__ + +## WuerstchenGeneratorPipelineOutput +[[autodoc]] pipelines.WuerstchenGeneratorPipelineOutput From 36e9722a10ee95a54a17b34c5bd4ebbcc50feb7e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 10:45:57 +0200 Subject: [PATCH 075/181] rename --- docs/source/en/api/pipelines/{wuerstchen.mdx => wuerstchen.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/source/en/api/pipelines/{wuerstchen.mdx => wuerstchen.md} (100%) diff --git a/docs/source/en/api/pipelines/wuerstchen.mdx b/docs/source/en/api/pipelines/wuerstchen.md similarity index 100% rename from docs/source/en/api/pipelines/wuerstchen.mdx rename to docs/source/en/api/pipelines/wuerstchen.md From f620d833f1deb89770ab83283be3b2bca6071da0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 10:47:58 +0200 Subject: [PATCH 076/181] toc --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ac3e02a27f74..608d4ccbd96c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -288,6 +288,8 @@ title: Versatile Diffusion - local: api/pipelines/vq_diffusion title: VQ Diffusion + - local: api/pipelines/wuerstchen + title: Wuerstchen title: Pipelines - sections: - local: api/schedulers/overview From 61c137ca8c1e9b9edbd6471ace1d2311cfc130da Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 12:50:02 +0200 Subject: [PATCH 077/181] cleanup --- docs/source/en/api/pipelines/wuerstchen.md | 13 +++--- .../wuerstchen/pipeline_wuerstchen.py | 45 +++++-------------- .../wuerstchen/pipeline_wuerstchen_prior.py | 30 ++----------- 3 files changed, 18 insertions(+), 70 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 58fcbc55c7ea..ba462f79a4a5 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -8,18 +8,15 @@ The abstract from the paper is: The original codebase can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). -## VQDiffusionPipeline -[[autodoc]] VQDiffusionPipeline +## WuerstchenPriorPipeline +[[autodoc]] WuerstchenGeneratorPipeline - all - __call__ -## ImagePipelineOutput -[[autodoc]] pipelines.ImagePipelineOutput +## WuerstchenPriorPipelineOutput +[[autodoc]] WuerstchenPriorPipelineOutput ## WuerstchenGeneratorPipeline [[autodoc]] WuerstchenGeneratorPipeline - all - - __call__ - -## WuerstchenGeneratorPipelineOutput -[[autodoc]] pipelines.WuerstchenGeneratorPipelineOutput + - __call__ \ No newline at end of file diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index abe93dcfd513..8cb08479ada0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -21,14 +21,11 @@ from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modules import DiffNeXt, EfficientNetEncoder -# from .diffuzz import Diffuzz - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -38,35 +35,21 @@ >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( - ... "kashif/wuerstchen-prior", torch_dtype=torch.float16 + ... "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=torch.float16 ... ).to("cuda") >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain( - ... "kashif/wuerstchen-gen", torch_dtype=torch.float16 + ... "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) - >>> images = gen_pipe(prior_output.image_embeds, prior_output.text_embeds) + >>> images = gen_pipe(prior_output.image_embeds, prompt=prompt) ``` """ - default_inference_steps_b = {0.0: 12} -@dataclass -class WuerstchenGeneratorPipelineOutput(BaseOutput): - """ - Output class for WuerstchenPriorPipeline. - - Args: - images (`torch.FloatTensor` or `np.ndarray`) - Generated images for text prompt. - """ - - images: Union[torch.FloatTensor, np.ndarray] - - class WuerstchenGeneratorPipeline(DiffusionPipeline): """ Pipeline for generating images from the Wuerstchen model. @@ -189,7 +172,6 @@ def _encode_prompt( uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( @@ -247,7 +229,7 @@ def __call__( guidance_scale: float = 3.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pt", # pt only + output_type: Optional[str] = "pill", return_dict: bool = True, ): device = self._execution_device @@ -292,11 +274,6 @@ def __call__( latents, self.scheduler, ) - # from transformers import AutoTokenizer, CLIPTextModel - # text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to(device) - # tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - # clip_tokens = tokenizer([""] * latents.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to(device) - # clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype) for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) @@ -324,14 +301,12 @@ def __call__( ).prev_sample images = self.vqgan.decode(latents).sample.clamp(0, 1) + images = images.permute(0, 2, 3, 1).cpu().numpy() - if output_type not in ["pt", "np"]: - raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") - - if output_type == "np": - images = images.permute(0, 2, 3, 1).cpu().numpy() + if output_type == "pil": + images = self.numpy_to_pil(images) if not return_dict: return images - return WuerstchenGeneratorPipelineOutput(images) + return ImagePipelineOutput(images) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 2d4ece9daeb6..046c8880a4db 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -25,34 +25,25 @@ from .prior import Prior -# from .diffuzz import Diffuzz - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + >>> from diffusers import WuerstchenPriorPipeline >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( - ... "kashif/wuerstchen-prior", torch_dtype=torch.float16 - ... ).to("cuda") - >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain( - ... "kashif/wuerstchen-gen", torch_dtype=torch.float16 + ... "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) - >>> images = gen_pipe(prior_output.image_embeds, prior_output.text_embeds) ``` """ default_inference_steps_c = {2 / 3: 20, 0.0: 10} -# default_inference_steps_c = {0.0: 60} -default_inference_steps_b = {0.0: 30} @dataclass @@ -186,7 +177,6 @@ def _encode_prompt( uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( @@ -213,7 +203,7 @@ def __call__( num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pt", # pt only + output_type: Optional[str] = "pt", return_dict: bool = True, ): device = self._execution_device @@ -260,9 +250,6 @@ def __call__( for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) - # print(torch.cat([latents] * 2).shape, latents.dtype) - # print(ratio, ratio.shape, ratio.dtype) - # print(text_encoder_hidden_states.shape, text_encoder_hidden_states.dtype) predicted_image_embedding = self.prior( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, @@ -282,20 +269,9 @@ def __call__( generator=generator, ).prev_sample - # t_start = 1.0 - # for t_end, steps in inference_steps.items(): - # steps = torch.linspace(t_start, t_end, steps + 1, dtype=dtype, device=device) - # latents = self.inference_loop( - # latents, steps, text_encoder_hidden_states, do_classifier_free_guidance, guidance_scale, generator - # ) - # t_start = t_end - # normalize the latents latents = latents * 42.0 - 1.0 - if output_type not in ["pt", "np"]: - raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") - if output_type == "np": latents = latents.cpu().numpy() text_encoder_hidden_states = text_encoder_hidden_states.cpu().numpy() From e5127d5ef285c9aea24926e8da56918aac5ff9b3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 12:50:58 +0200 Subject: [PATCH 078/181] remvoe test script --- scripts/wuerstchen_pipeline_test.py | 139 ---------------------------- 1 file changed, 139 deletions(-) delete mode 100644 scripts/wuerstchen_pipeline_test.py diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py deleted file mode 100644 index 5c13518b5af6..000000000000 --- a/scripts/wuerstchen_pipeline_test.py +++ /dev/null @@ -1,139 +0,0 @@ -import os - -import numpy as np -import torch -import transformers -from PIL import Image - -from diffusers import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline - - -transformers.utils.logging.set_verbosity_error() - - -def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - -# effnet_preprocess = torchvision.transforms.Compose( -# [ -# torchvision.transforms.Resize( -# 768, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True -# ), -# torchvision.transforms.CenterCrop(768), -# torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), -# ] -# ) - -# transforms = torchvision.transforms.Compose( -# [ -# torchvision.transforms.ToTensor(), -# torchvision.transforms.Resize(1024), -# torchvision.transforms.RandomCrop(1024), -# ] -# ) -device = "cuda" -dtype = torch.float16 -batch_size = 2 - -# generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) -# generator_pipeline = generator_pipeline.to("cuda") -# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cuda") -# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - -# image = Image.open("C:\\Users\\d6582\\Documents\\ml\\wand\\finetuning\\images\\fernando\\IMG_0352.JPG") -# image = effnet_preprocess(transforms(image).unsqueeze(0).expand(batch_size, -1, -1, -1)).to("cuda").to(dtype) -# print(image.shape) - -# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -# negative_prompt = "low resolution, low detail, bad quality, blurry" - -# clip_tokens = tokenizer([caption] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") -# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype) -# clip_tokens_uncond = tokenizer([negative_prompt] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") -# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype) - -# image_embeds = generator_pipeline.encode_image(image) -# generator_output = generator_pipeline(image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images -# images = numpy_to_pil(generator_output) -# os.makedirs("samples", exist_ok=True) -# for i, image in enumerate(images): -# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - -torch.manual_seed(42) - -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( - "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype -) -prior_pipeline = prior_pipeline.to("cuda") -generator_pipeline = generator_pipeline.to("cuda") -# generator_pipeline.vqgan.to(torch.float16) -# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - -# negative_prompt = "low resolution, low detail, bad quality, blurry" -negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" -# negative_prompt = "" -caption = ( - "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" -) -# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -# clip_tokens = tokenizer( -# [caption] * batch_size, -# truncation=True, -# padding="max_length", -# max_length=tokenizer.model_max_length, -# return_tensors="pt", -# ) -# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) -# clip_tokens_uncond = tokenizer( -# [negative_prompt] * batch_size, -# truncation=True, -# padding="max_length", -# max_length=tokenizer.model_max_length, -# return_tensors="pt", -# ) -# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) - -prior_output = prior_pipeline( - caption, - height=1024, - width=1536, - guidance_scale=8.0, - num_images_per_prompt=batch_size, - negative_prompt=negative_prompt, -) -generator_output = generator_pipeline( - predicted_image_embeddings=prior_output.image_embeds, - prompt=caption, - negative_prompt=negative_prompt, - guidance_scale=0.0, - output_type="np", -).images -images = numpy_to_pil(generator_output) -os.makedirs("samples", exist_ok=True) -for i, image in enumerate(images): - image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - - -# caption = input("Prompt please: ") -# while caption != "q": -# prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) -# generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds, output_type="np").images -# images = numpy_to_pil(generator_output) - -# os.makedirs("samples", exist_ok=True) -# for i, image in enumerate(images): -# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - -# caption = input("Prompt please: ") From 61a5ebc5a53e59a5e97a33ab58d3c4bf2c14b38e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 12:53:12 +0200 Subject: [PATCH 079/181] fix-copies --- scripts/convert_wuerstchen_prior.py | 1 - .../wuerstchen/pipeline_wuerstchen.py | 1 - src/diffusers/utils/dummy_pt_objects.py | 26 +++++++++---------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/scripts/convert_wuerstchen_prior.py b/scripts/convert_wuerstchen_prior.py index db0ee9100b1f..f6d787cdf413 100644 --- a/scripts/convert_wuerstchen_prior.py +++ b/scripts/convert_wuerstchen_prior.py @@ -34,4 +34,3 @@ ) prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") - diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 8cb08479ada0..25b1ed7ad4d0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import List, Optional, Union import numpy as np diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ea04baee9d72..e16499c73a83 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -77,7 +77,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQModelPaella(metaclass=DummyObject): +class MultiAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -92,7 +92,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class MultiAdapter(metaclass=DummyObject): +class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PriorTransformer(metaclass=DummyObject): +class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -122,7 +122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T2IAdapter(metaclass=DummyObject): +class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -137,7 +137,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T5FilmDecoder(metaclass=DummyObject): +class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -152,7 +152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Transformer2DModel(metaclass=DummyObject): +class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -167,7 +167,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet1DModel(metaclass=DummyObject): +class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -182,7 +182,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DConditionModel(metaclass=DummyObject): +class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DModel(metaclass=DummyObject): +class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet3DConditionModel(metaclass=DummyObject): +class VQModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -227,7 +227,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQModel(metaclass=DummyObject): +class VQModelPaella(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -600,7 +600,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDPMWuerstchenScheduler(metaclass=DummyObject): +class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -615,7 +615,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDPMScheduler(metaclass=DummyObject): +class DDPMWuerstchenScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From b122dddc02c6979cbdea47a301dfdf81da0fb19e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 7 Aug 2023 10:56:28 +0000 Subject: [PATCH 080/181] fix multi images --- .../wuerstchen/pipeline_wuerstchen.py | 21 ++++++++----------- .../wuerstchen/pipeline_wuerstchen_prior.py | 11 ++-------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index abe93dcfd513..19fac3bef8f8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -245,9 +245,10 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, inference_steps: dict = None, guidance_scale: float = 3.0, + num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pt", # pt only + output_type: Optional[str] = "pil", return_dict: bool = True, ): device = self._execution_device @@ -257,21 +258,14 @@ def __call__( if inference_steps is None: inference_steps = default_inference_steps_b - if negative_prompt is None: - negative_prompt = "" - if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - elif not isinstance(negative_prompt, list) and negative_prompt is not None: - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + text_encoder_hidden_states = self._encode_prompt( - prompt, device, predicted_image_embeddings.size(0), do_classifier_free_guidance, negative_prompt + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs( predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device ) @@ -325,11 +319,14 @@ def __call__( images = self.vqgan.decode(latents).sample.clamp(0, 1) - if output_type not in ["pt", "np"]: - raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") if output_type == "np": images = images.permute(0, 2, 3, 1).cpu().numpy() + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().numpy() + images = self.numpy_to_pil(images) if not return_dict: return images diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 2d4ece9daeb6..e6effe1129a0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -223,19 +223,12 @@ def __call__( if inference_steps is None: inference_steps = default_inference_steps_c - if negative_prompt is None: - negative_prompt = "" - if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - elif not isinstance(negative_prompt, list) and negative_prompt is not None: - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - + batch_size = len(prompt) if isinstance(prompt, list) else 1 text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) @@ -244,7 +237,7 @@ def __call__( latent_height = 128 * (height // 128) // (1024 // 24) latent_width = 128 * (width // 128) // (1024 // 24) num_channels = self.prior.config.c_in - effnet_features_shape = (num_images_per_prompt, num_channels, latent_height, latent_width) + effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) self.scheduler.set_timesteps(inference_steps, device=device) timesteps = self.scheduler.timesteps From f24ee4747fd964a5d085981d074663e119c613c9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:00:13 +0200 Subject: [PATCH 081/181] remove dup --- src/diffusers/models/unet_2d_blocks.py | 10 +-- src/diffusers/models/vq_model.py | 114 ------------------------- 2 files changed, 1 insertion(+), 123 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 68ad1e2a29ac..6f3037d624f9 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -23,15 +23,7 @@ from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel -from .resnet import ( - Downsample2D, - FirDownsample2D, - FirUpsample2D, - KDownsample2D, - KUpsample2D, - ResnetBlock2D, - Upsample2D, -) +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index ab8fa008bd78..393a638d483b 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -20,7 +20,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, apply_forward_hook from .modeling_utils import ModelMixin -from .resnet import MixingResidualBlock from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer @@ -37,119 +36,6 @@ class VQEncoderOutput(BaseOutput): latents: torch.FloatTensor -class VQModelPaella(ModelMixin, ConfigMixin): - r"""VQ-VAE model from Paella model. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. - levels (int, *optional*, defaults to 2): Number of levels in the model. - bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. - c_hidden (int, *optional*, defaults to 384): Number of hidden channels in the model. - c_latent (int, *optional*, defaults to 4): Number of latent channels in the model. - codebook_size (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. - scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_down_scale_factor: int = 2, - levels: int = 2, - bottleneck_blocks: int = 12, - c_hidden: int = 384, - c_latent: int = 4, - codebook_size: int = 8192, - scale_factor: float = 0.3764, - ): - super().__init__() - - c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] - self.in_block = nn.Sequential( - nn.PixelUnshuffle(up_down_scale_factor), - nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), - ) - - down_blocks = [] - for i in range(levels): - if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) - block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) - down_blocks.append(block) - down_blocks.append( - nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - ) - ) - self.down_blocks = nn.Sequential(*down_blocks) - self.vquantizer = VectorQuantizer(codebook_size, vq_embed_dim=c_latent, legacy=False, beta=0.25) - - # Decoder blocks - up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] - for i in range(levels): - for j in range(bottleneck_blocks if i == 0 else 1): - block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) - up_blocks.append(block) - if i < levels - 1: - up_blocks.append( - nn.ConvTranspose2d( - c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 - ) - ) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), - nn.PixelShuffle(up_down_scale_factor), - ) - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: - h = self.in_block(x) - h = self.down_blocks(h) / self.config.scale_factor - - if not return_dict: - return (h,) - - return VQEncoderOutput(latents=h) - - def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - if not force_not_quantize: - quant, _, _ = self.vquantizer(h * self.config.scale_factor) - else: - quant = h * self.config.scale_factor - - x = self.up_blocks(quant) - dec = self.out_block(x) - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - h = self.encode(x).latents - dec = self.decode(h).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - class VQModel(ModelMixin, ConfigMixin): r""" A VQ-VAE model for decoding latent representations. From ce23ef782434613f0cad1c5ba34ef034626779ec Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:05:03 +0200 Subject: [PATCH 082/181] remove unused modules --- src/diffusers/models/resnet.py | 50 ---------------------- src/diffusers/models/unet_2d_condition.py | 51 ----------------------- 2 files changed, 101 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 52d5794bc3af..0a38ead6ad83 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -674,56 +674,6 @@ def forward(self, inputs): return output -class GlobalResponseNorm(nn.Module): - "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" - - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, inputs): - gx = torch.norm(inputs, p=2, dim=(1, 2), keepdim=True) - nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (inputs * nx) + self.beta + inputs - - -class GlobalResponseResidualBlock(nn.Module): - def __init__(self, inp_channels, channel_skip=None, kernel_size=3, dropout=0.0) -> None: - super().__init__() - - # depthwise - self.depthwise = nn.Conv2d( - inp_channels + channel_skip, - inp_channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - groups=inp_channels, - ) - self.norm = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) - - # channelwise - self.channelwise = nn.Sequential( - nn.Linear(inp_channels, inp_channels * 4), - nn.GELU(), - GlobalResponseNorm(inp_channels * 4), - nn.Dropout(dropout), - nn.Linear(inp_channels * 4, inp_channels), - ) - - @staticmethod - def _norm(x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - def forward(self, inputs, inputs_skip=None): - inputs_res = inputs - if inputs_skip is not None: - inputs = torch.cat([inputs, inputs_skip], dim=1) - inputs = self._norm(self.depthwise(inputs), self.norm).permute(0, 2, 3, 1) - inputs = self.channelwise(inputs).permute(0, 3, 1, 2) - return inputs + inputs_res - - class MixingResidualBlock(nn.Module): def __init__(self, inp_channels, c_hidden): super().__init__() diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index eca6011c23c9..fea1b4cd7823 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -59,57 +59,6 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor = None -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - -class PaellaUNet2dConditionalModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - c_in=256, - c_out=256, - num_labels=8192, - c_r=64, - patch_size=2, - c_cond=1024, - c_hidden=[640, 1280, 1280], - nhead=[-1, 16, 16], - blocks=[6, 16, 6], - level_config=["CT", "CTA", "CTA"], - clip_embd=1024, - byt5_embd=1536, - clip_seq_len=4, - kernel_size=3, - dropout=0.1, - self_attn=True, - ): - super().__init__() - if not isinstance(dropout, list): - dropout = [dropout] * len(c_hidden) - - # CONDITIONING - self.byt5_mapper = nn.Linear(byt5_embd, c_cond) - self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len) - self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len) - self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) - - self.in_mapper = nn.Sequential( - nn.Embedding(num_labels, c_in), nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6) - ) - self.embedding = nn.Sequential( - nn.PixelUnshuffle(patch_size), - nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), - ) - - class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample From 163cd2b889fd6264484a5c8ea1de60f86e4765e3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:06:01 +0200 Subject: [PATCH 083/181] undo changes for debugging --- .gitignore | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.gitignore b/.gitignore index 776df19820d3..358d6b0a9eae 100644 --- a/.gitignore +++ b/.gitignore @@ -174,9 +174,3 @@ tags .ruff_cache wandb -scripts/models/ -scripts/warp-diffusion/ -scripts/warp-diffusion-test/ -scripts/samples/ -scripts/text_encoding_colab.pt -scripts/text_encoding.pt From 59e5f15f1556ed54a1ac98d71da687540df7b20d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:06:39 +0200 Subject: [PATCH 084/181] no new line --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 358d6b0a9eae..45602a1f547e 100644 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,4 @@ tags # ruff .ruff_cache -wandb +wandb \ No newline at end of file From e6f0f75d34f223d114e0b6b4bd89753418f33c7a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:09:55 +0200 Subject: [PATCH 085/181] remove dup conversion script --- scripts/convert_wuerstchen_prior.py | 36 ----------------------------- 1 file changed, 36 deletions(-) delete mode 100644 scripts/convert_wuerstchen_prior.py diff --git a/scripts/convert_wuerstchen_prior.py b/scripts/convert_wuerstchen_prior.py deleted file mode 100644 index f6d787cdf413..000000000000 --- a/scripts/convert_wuerstchen_prior.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -import torch -from transformers import AutoTokenizer, CLIPTextModel - -from diffusers import ( - DDPMWuerstchenScheduler, - WuerstchenPriorPipeline, -) -from diffusers.pipelines.wuerstchen import Prior - - -model_path = "models/" -device = "cpu" - -# Clip Text encoder and tokenizer -text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") -tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - -# Prior -state_dict = torch.load(os.path.join(model_path, "model_v2_stage_c.pt"), map_location=device) -prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) -prior_model.load_state_dict(state_dict["ema_state_dict"]) - -# scheduler -scheduler = DDPMWuerstchenScheduler() - -# Prior pipeline -prior_pipeline = WuerstchenPriorPipeline( - prior=prior_model, - text_encoder=text_encoder, - tokenizer=tokenizer, - scheduler=scheduler, -) - -prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") From 35e55a7a33f0cf5fb1cafc02a61be635b0e1b7e6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:14:08 +0200 Subject: [PATCH 086/181] fix doc string --- src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 9da0f283251c..a53df6039d63 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -181,12 +181,11 @@ def step( sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`DDPMWuerstchenSchedulerOutput`] or `tuple`: [`DDPMWuerstchenSchedulerOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ dtype = model_output.dtype From 78cd405cad012acba4dd91d0bf6f5b2e4ccb2bdb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:17:06 +0200 Subject: [PATCH 087/181] cleanup --- src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index a53df6039d63..aa12e93e87e5 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -168,7 +168,6 @@ def step( timestep: int, sample: torch.FloatTensor, generator=None, - # prev_t=None, return_dict: bool = True, ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]: """ From 3ddee343a384bbad04ab369b46efd40d6285ebeb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:23:22 +0200 Subject: [PATCH 088/181] pass default args --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 7 +------ .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 9 +-------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 64eb04a377d0..478bbab1c75a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -46,8 +46,6 @@ ``` """ -default_inference_steps_b = {0.0: 12} - class WuerstchenGeneratorPipeline(DiffusionPipeline): """ @@ -224,7 +222,7 @@ def __call__( predicted_image_embeddings: torch.Tensor, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - inference_steps: dict = None, + inference_steps: dict[float, int] = {0.0: 12}, guidance_scale: float = 3.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -236,9 +234,6 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 - if inference_steps is None: - inference_steps = default_inference_steps_b - if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index c173bbac84e5..2b24f1070af1 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -43,9 +43,6 @@ """ -default_inference_steps_c = {2 / 3: 20, 0.0: 10} - - @dataclass class WuerstchenPriorPipelineOutput(BaseOutput): """ @@ -197,7 +194,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - inference_steps: dict = None, + inference_steps: dict[float, int] = {2 / 3: 20, 0.0: 10}, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -207,12 +204,8 @@ def __call__( return_dict: bool = True, ): device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 - if inference_steps is None: - inference_steps = default_inference_steps_c - if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): From fca022a74350a98abf90636a10dca6cfe63276ed Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:27:22 +0200 Subject: [PATCH 089/181] dup permute --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 478bbab1c75a..063f4a2bbc6c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -79,7 +79,6 @@ def __init__( efficient_net: EfficientNetEncoder, ) -> None: super().__init__() - self.multiple = 128 self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, @@ -289,7 +288,6 @@ def __call__( ).prev_sample images = self.vqgan.decode(latents).sample.clamp(0, 1) - images = images.permute(0, 2, 3, 1).cpu().numpy() if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") From 903ba6f433f50f50dcf7767792cf649db83851e5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:44:07 +0200 Subject: [PATCH 090/181] fix some tests --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 4 ++-- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 4 ++-- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 063f4a2bbc6c..f6a312d3bbba 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -221,7 +221,7 @@ def __call__( predicted_image_embeddings: torch.Tensor, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - inference_steps: dict[float, int] = {0.0: 12}, + num_inference_steps: dict[float, int] = {0.0: 12}, guidance_scale: float = 3.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -250,7 +250,7 @@ def __call__( latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) effnet_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) - self.scheduler.set_timesteps(inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps latents = self.prepare_latents( diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 2b24f1070af1..06cc4c817a7d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -194,7 +194,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - inference_steps: dict[float, int] = {2 / 3: 20, 0.0: 10}, + num_inference_steps: dict[float, int] = {2 / 3: 20, 0.0: 10}, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -222,7 +222,7 @@ def __call__( num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) - self.scheduler.set_timesteps(inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps latents = self.prepare_latents( diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index c315f9fe63eb..3d3d42e9702f 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -23,7 +23,7 @@ CLIPTokenizer, ) -from diffusers import WuerstchenPriorPipeline, DDPMScheduler +from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline from diffusers.pipelines.wuerstchen import Prior from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps @@ -95,7 +95,7 @@ def dummy_prior(self): "c_in": 2, "c": 8, "depth": 2, - "c_cond": 37, + "c_cond": 32, "c_r": 8, "nhead": 2, "latent_size": (2, 2), @@ -109,7 +109,7 @@ def get_dummy_components(self): text_encoder = self.dummy_text_encoder tokenizer = self.dummy_tokenizer - scheduler = DDPMScheduler() + scheduler = DDPMWuerstchenScheduler() components = { "prior": prior, @@ -129,12 +129,12 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "horse", "generator": generator, "guidance_scale": 4.0, - "num_inference_steps": 2, + "num_inference_steps": {0.0: 2}, "output_type": "np", } return inputs - def test_kandinsky_prior(self): + def test_wuerstchen_prior(self): device = "cpu" components = self.get_dummy_components() From 67eaff6a05219171c22069db8f4a9fc1b9cf84e1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 13:55:25 +0200 Subject: [PATCH 091/181] fix prepare_latents --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 11 ++--------- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 11 ++--------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index f6a312d3bbba..dafb98e5ef73 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -89,7 +89,7 @@ def __init__( ) self.register_to_config() - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + def prepare_latents(self, shape, dtype, device, generator, latents): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -253,14 +253,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - latents = self.prepare_latents( - effnet_features_shape, - dtype, - device, - generator, - latents, - self.scheduler, - ) + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents) for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 06cc4c817a7d..7eb865f54fdc 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -95,7 +95,7 @@ def __init__( ) self.register_to_config() - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + def prepare_latents(self, shape, dtype, device, generator, latents): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -225,14 +225,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - latents = self.prepare_latents( - effnet_features_shape, - dtype, - device, - generator, - latents, - self.scheduler, - ) + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents) for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) From 373113156f3e5ef300578231ac5c0abc7abc393f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 14:20:19 +0200 Subject: [PATCH 092/181] move Prior class to modules --- .../pipelines/wuerstchen/__init__.py | 3 +- src/diffusers/pipelines/wuerstchen/modules.py | 50 +++++++ .../wuerstchen/pipeline_wuerstchen.py | 19 ++- .../wuerstchen/pipeline_wuerstchen_prior.py | 21 ++- src/diffusers/pipelines/wuerstchen/prior.py | 136 ------------------ 5 files changed, 88 insertions(+), 141 deletions(-) delete mode 100644 src/diffusers/pipelines/wuerstchen/prior.py diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 54a7dee3b533..93409b0f630e 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,7 +2,6 @@ if is_transformers_available() and is_torch_available(): - from .modules import DiffNeXt, EfficientNetEncoder + from .modules import DiffNeXt, Prior, EfficientNetEncoder from .pipeline_wuerstchen import WuerstchenGeneratorPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline - from .prior import Prior diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 715a0b0a221e..8b32e6b886a6 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -337,3 +337,53 @@ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=Tr return (x_in - a) / b else: return a, b + + +class Prior(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): + super().__init__() + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + LayerNorm2d(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + # denoised = a / (1-(1-b).pow(2)).sqrt() + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index dafb98e5ef73..978c840e6424 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -20,7 +20,7 @@ from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler -from ...utils import logging, randn_tensor +from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modules import DiffNeXt, EfficientNetEncoder @@ -99,6 +99,23 @@ def prepare_latents(self, shape, dtype, device, generator, latents): return latents + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, generator, + text_encoder, vqgan and efficient_net have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.generator, self.text_encoder, self.vqgan, self.efficient_net]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + def _encode_prompt( self, prompt, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 7eb865f54fdc..955b775770f8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -20,9 +20,9 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, logging, randn_tensor +from ...utils import BaseOutput, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from .prior import Prior +from .modules import Prior logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -95,6 +95,23 @@ def __init__( ) self.register_to_config() + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, prior and + text_encoder have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to + GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.text_encoder, self.prior]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + def prepare_latents(self, shape, dtype, device, generator, latents): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/wuerstchen/prior.py b/src/diffusers/pipelines/wuerstchen/prior.py deleted file mode 100644 index f3dae497afdc..000000000000 --- a/src/diffusers/pipelines/wuerstchen/prior.py +++ /dev/null @@ -1,136 +0,0 @@ -import math - -import torch -import torch.nn as nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config - -from ...models.modeling_utils import ModelMixin - - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - -class TimestepBlock(nn.Module): - def __init__(self, c, c_timestep): - super().__init__() - self.mapper = nn.Linear(c_timestep, c * 2) - - def forward(self, x, t): - a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) - return x * (1 + a) + b - - -class Attention2D(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn: - kv = torch.cat([x, kv], dim=1) - x = self.attn(x, kv, kv, need_weights=False)[0] - x = x.permute(0, 2, 1).view(*orig_shape) - return x - - -class ResBlock(nn.Module): - def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): - super().__init__() - self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) - ) - - def forward(self, x, x_skip=None): - x_res = x - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) - x = self.channelwise(x).permute(0, 3, 1, 2) - return x + x_res - - -# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 -class GlobalResponseNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -class AttnBlock(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) - - def forward(self, x, kv): - kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x - - -class Prior(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): - super().__init__() - self.c_r = c_r - self.projection = nn.Conv2d(c_in, c, kernel_size=1) - self.cond_mapper = nn.Sequential( - nn.Linear(c_cond, c), - nn.LeakyReLU(0.2), - nn.Linear(c, c), - ) - - self.blocks = nn.ModuleList() - for _ in range(depth): - self.blocks.append(ResBlock(c, dropout=dropout)) - self.blocks.append(TimestepBlock(c, c_r)) - self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) - self.out = nn.Sequential( - LayerNorm2d(c, elementwise_affine=False, eps=1e-6), - nn.Conv2d(c, c_in * 2, kernel_size=1), - ) - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode="constant") - return emb.to(dtype=r.dtype) - - def forward(self, x, r, c): - x_in = x - x = self.projection(x) - c_embed = self.cond_mapper(c) - r_embed = self.gen_r_embedding(r) - for block in self.blocks: - if isinstance(block, AttnBlock): - x = block(x, c_embed) - elif isinstance(block, TimestepBlock): - x = block(x, r_embed) - else: - x = block(x) - a, b = self.out(x).chunk(2, dim=1) - # denoised = a / (1-(1-b).pow(2)).sqrt() - return (x_in - a) / ((1 - b).abs() + 1e-5) From 09ca25b7c0a252dae75a20a279044e10789bb1f9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 14:41:25 +0200 Subject: [PATCH 093/181] offload only the text encoder and vqgan --- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 8 ++++---- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 93409b0f630e..877db65d8c9b 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,6 +2,6 @@ if is_transformers_available() and is_torch_available(): - from .modules import DiffNeXt, Prior, EfficientNetEncoder + from .modules import DiffNeXt, EfficientNetEncoder, Prior from .pipeline_wuerstchen import WuerstchenGeneratorPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 978c840e6424..d7930cacfe28 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -101,9 +101,9 @@ def prepare_latents(self, shape, dtype, device, generator, latents): def enable_sequential_cpu_offload(self, gpu_id=0): r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, generator, - text_encoder, vqgan and efficient_net have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, + vqgan and efficient_net have their state dicts saved to CPU and then are moved to a `torch.device('meta') and + loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -112,7 +112,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.generator, self.text_encoder, self.vqgan, self.efficient_net]: + for cpu_offloaded_model in [self.text_encoder, self.vqgan, self.efficient_net]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 955b775770f8..dd905b13c0b6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -97,9 +97,9 @@ def __init__( def enable_sequential_cpu_offload(self, gpu_id=0): r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, prior and - text_encoder have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to - GPU only when their specific submodule has its `forward` method called. + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the text_encoder + have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when + their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -108,7 +108,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder, self.prior]: + for cpu_offloaded_model in [self.text_encoder]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) From 8f7a74a63626451ace1e4091a6f9378ea727de57 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 7 Aug 2023 14:56:49 +0200 Subject: [PATCH 094/181] fix resolution calculation for prior --- scripts/wuerstchen_pipeline_test.py | 107 ++++++++---------- .../wuerstchen/pipeline_wuerstchen_prior.py | 4 +- 2 files changed, 51 insertions(+), 60 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 5c13518b5af6..70f1d7f9287e 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -42,7 +42,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # ) device = "cuda" dtype = torch.float16 -batch_size = 2 +batch_size = 1 # generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) # generator_pipeline = generator_pipeline.to("cuda") @@ -76,64 +76,55 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: ) prior_pipeline = prior_pipeline.to("cuda") generator_pipeline = generator_pipeline.to("cuda") -# generator_pipeline.vqgan.to(torch.float16) -# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") -# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - -# negative_prompt = "low resolution, low detail, bad quality, blurry" -negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, drawing, helmet" -# negative_prompt = "" +negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" caption = ( - "Bee flying out of a glass jar in a green and red leafy basket, glass and lens flare, diffuse lighting elegant" + "A captivating artwork of a mysterious stone golem" ) # caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -caption = "An armchair in the shape of an avocado" -# clip_tokens = tokenizer( -# [caption] * batch_size, -# truncation=True, -# padding="max_length", -# max_length=tokenizer.model_max_length, -# return_tensors="pt", -# ) -# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype).to(device) -# clip_tokens_uncond = tokenizer( -# [negative_prompt] * batch_size, -# truncation=True, -# padding="max_length", -# max_length=tokenizer.model_max_length, -# return_tensors="pt", + +# prior_output = prior_pipeline( +# caption, +# height=1024, +# width=1024, +# guidance_scale=8.0, +# num_images_per_prompt=batch_size, +# negative_prompt=negative_prompt, # ) -# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype).to(device) - -prior_output = prior_pipeline( - caption, - height=1024, - width=1536, - guidance_scale=8.0, - num_images_per_prompt=batch_size, - negative_prompt=negative_prompt, -) -generator_output = generator_pipeline( - predicted_image_embeddings=prior_output.image_embeds, - prompt=caption, - negative_prompt=negative_prompt, - guidance_scale=0.0, - output_type="np", -).images -images = numpy_to_pil(generator_output) -os.makedirs("samples", exist_ok=True) -for i, image in enumerate(images): - image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - - -# caption = input("Prompt please: ") -# while caption != "q": -# prior_output = prior_pipeline(caption, num_images_per_prompt=4, negative_prompt=negative_prompt) -# generator_output = generator_pipeline(prior_output.image_embeds, prior_output.text_embeds, output_type="np").images -# images = numpy_to_pil(generator_output) - -# os.makedirs("samples", exist_ok=True) -# for i, image in enumerate(images): -# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - -# caption = input("Prompt please: ") +# generator_output = generator_pipeline( +# predicted_image_embeddings=prior_output.image_embeds, +# prompt=caption, +# negative_prompt=negative_prompt, +# guidance_scale=0.0, +# output_type="np", +# ).images +# images = numpy_to_pil(generator_output) +# os.makedirs("samples", exist_ok=True) +# for i, image in enumerate(images): +# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) + + +caption = input("Prompt please: ") +while caption != "q": + prior_output = prior_pipeline( + caption, + height=1024, + width=4096, + guidance_scale=8.0, + num_images_per_prompt=batch_size, + negative_prompt=negative_prompt, + ) + generator_output = generator_pipeline( + predicted_image_embeddings=prior_output.image_embeds, + prompt=caption, + negative_prompt=negative_prompt, + guidance_scale=0.0, + output_type="np", + ).images + print("Done Sampling") + images = numpy_to_pil(generator_output) + + os.makedirs("samples", exist_ok=True) + for i, image in enumerate(images): + image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) + + caption = input("Prompt please: ") diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 2d4ece9daeb6..5963cb3cf5f6 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -241,8 +241,8 @@ def __call__( ) dtype = text_encoder_hidden_states.dtype - latent_height = 128 * (height // 128) // (1024 // 24) - latent_width = 128 * (width // 128) // (1024 // 24) + latent_height = int(128 * (height / 128) / (1024 / 24)) + latent_width = int(128 * (width / 128) / (1024 / 24)) num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt, num_channels, latent_height, latent_width) From cc80e2b0bf2c1b0e968d98eb8cec4785dc1618fa Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 7 Aug 2023 14:58:43 +0200 Subject: [PATCH 095/181] nip --- .gitignore | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 45602a1f547e..2b81433477dd 100644 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,10 @@ tags # ruff .ruff_cache -wandb \ No newline at end of file +wandb +scripts/models/ +scripts/warp-diffusion/ +scripts/warp-diffusion-test/ +scripts/samples/ +scripts/text_encoding_colab.pt +scripts/text_encoding.pt \ No newline at end of file From f74f6885b245bf9a6327fd2e0e1d5bc0302facf8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 15:30:06 +0200 Subject: [PATCH 096/181] removed testing script --- scripts/wuerstchen_pipeline_test.py | 130 ---------------------------- 1 file changed, 130 deletions(-) delete mode 100644 scripts/wuerstchen_pipeline_test.py diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py deleted file mode 100644 index 70f1d7f9287e..000000000000 --- a/scripts/wuerstchen_pipeline_test.py +++ /dev/null @@ -1,130 +0,0 @@ -import os - -import numpy as np -import torch -import transformers -from PIL import Image - -from diffusers import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline - - -transformers.utils.logging.set_verbosity_error() - - -def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - -# effnet_preprocess = torchvision.transforms.Compose( -# [ -# torchvision.transforms.Resize( -# 768, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True -# ), -# torchvision.transforms.CenterCrop(768), -# torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), -# ] -# ) - -# transforms = torchvision.transforms.Compose( -# [ -# torchvision.transforms.ToTensor(), -# torchvision.transforms.Resize(1024), -# torchvision.transforms.RandomCrop(1024), -# ] -# ) -device = "cuda" -dtype = torch.float16 -batch_size = 1 - -# generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) -# generator_pipeline = generator_pipeline.to("cuda") -# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cuda") -# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - -# image = Image.open("C:\\Users\\d6582\\Documents\\ml\\wand\\finetuning\\images\\fernando\\IMG_0352.JPG") -# image = effnet_preprocess(transforms(image).unsqueeze(0).expand(batch_size, -1, -1, -1)).to("cuda").to(dtype) -# print(image.shape) - -# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" -# negative_prompt = "low resolution, low detail, bad quality, blurry" - -# clip_tokens = tokenizer([caption] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") -# clip_text_embeddings = text_encoder(**clip_tokens).last_hidden_state.to(dtype) -# clip_tokens_uncond = tokenizer([negative_prompt] * image.size(0), truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").to("cuda") -# clip_text_embeddings_uncond = text_encoder(**clip_tokens_uncond).last_hidden_state.to(dtype) - -# image_embeds = generator_pipeline.encode_image(image) -# generator_output = generator_pipeline(image_embeds, clip_text_embeddings, guidance_scale=0.0, output_type="np").images -# images = numpy_to_pil(generator_output) -# os.makedirs("samples", exist_ok=True) -# for i, image in enumerate(images): -# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - -torch.manual_seed(42) - -prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( - "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype -) -prior_pipeline = prior_pipeline.to("cuda") -generator_pipeline = generator_pipeline.to("cuda") -negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" -caption = ( - "A captivating artwork of a mysterious stone golem" -) -# caption = "princess | centered| key visual| intricate| highly detailed| breathtaking beauty| precise lineart| vibrant| comprehensive cinematic| Carne Griffiths| Conrad Roset" - -# prior_output = prior_pipeline( -# caption, -# height=1024, -# width=1024, -# guidance_scale=8.0, -# num_images_per_prompt=batch_size, -# negative_prompt=negative_prompt, -# ) -# generator_output = generator_pipeline( -# predicted_image_embeddings=prior_output.image_embeds, -# prompt=caption, -# negative_prompt=negative_prompt, -# guidance_scale=0.0, -# output_type="np", -# ).images -# images = numpy_to_pil(generator_output) -# os.makedirs("samples", exist_ok=True) -# for i, image in enumerate(images): -# image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - - -caption = input("Prompt please: ") -while caption != "q": - prior_output = prior_pipeline( - caption, - height=1024, - width=4096, - guidance_scale=8.0, - num_images_per_prompt=batch_size, - negative_prompt=negative_prompt, - ) - generator_output = generator_pipeline( - predicted_image_embeddings=prior_output.image_embeds, - prompt=caption, - negative_prompt=negative_prompt, - guidance_scale=0.0, - output_type="np", - ).images - print("Done Sampling") - images = numpy_to_pil(generator_output) - - os.makedirs("samples", exist_ok=True) - for i, image in enumerate(images): - image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) - - caption = input("Prompt please: ") From 829a394d6a4f534b14cb6b4618d4e365ba7ec0b3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 15:30:34 +0200 Subject: [PATCH 097/181] fix shape --- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 3d3d42e9702f..07db13634136 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -152,10 +152,10 @@ def test_wuerstchen_prior(self): return_dict=False, )[0] - image_slice = image[0, -10:] - image_from_tuple_slice = image_from_tuple[0, -10:] + image_slice = image[0, 0, 0, -10:] + image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] - assert image.shape == (1, 32) + assert image.shape == (1, 2, 24, 24) expected_slice = np.array( [-0.0532, 1.7120, 0.3656, -1.0852, -0.8946, -1.1756, 0.4348, 0.2482, 0.5146, -0.1156] From 170180a2eda4a869697930c8d94eee62362b69e5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 15:37:38 +0200 Subject: [PATCH 098/181] fix argument to set_timesteps --- src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index aa12e93e87e5..9821be44ecfe 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -136,15 +136,14 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = def set_timesteps( self, - inference_steps: Optional[dict] = None, + num_inference_steps: dict[float, int], device: Union[str, torch.device] = None, - timesteps: Optional[List[int]] = None, ): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: - num_inference_steps (`Optional[int]`): + num_inference_steps (`dict[float, int]`): the number of diffusion steps used when generating samples with a pre-trained model. If passed, then `timesteps` must be `None`. device (`str` or `torch.device`, optional): @@ -152,7 +151,7 @@ def set_timesteps( """ timesteps = None t_start = 1.0 - for t_end, steps in inference_steps.items(): + for t_end, steps in num_inference_steps.items(): steps = torch.linspace(t_start, t_end, steps + 1, device=device) t_start = t_end if timesteps is None: From 687de06ca56afd259725871d6623f8a6b0bb248e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 16:32:08 +0200 Subject: [PATCH 099/181] do not change .gitignore --- .gitignore | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 2b81433477dd..45602a1f547e 100644 --- a/.gitignore +++ b/.gitignore @@ -173,10 +173,4 @@ tags # ruff .ruff_cache -wandb -scripts/models/ -scripts/warp-diffusion/ -scripts/warp-diffusion-test/ -scripts/samples/ -scripts/text_encoding_colab.pt -scripts/text_encoding.pt \ No newline at end of file +wandb \ No newline at end of file From 0ca12eed80951e2f34e9297fb774341d51348c8a Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 7 Aug 2023 16:35:00 +0200 Subject: [PATCH 100/181] fix resolution calculations + readme --- scripts/wuerstchen_pipeline_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/wuerstchen_pipeline_test.py b/scripts/wuerstchen_pipeline_test.py index 70f1d7f9287e..e022048529c3 100644 --- a/scripts/wuerstchen_pipeline_test.py +++ b/scripts/wuerstchen_pipeline_test.py @@ -42,7 +42,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: # ) device = "cuda" dtype = torch.float16 -batch_size = 1 +batch_size = 2 # generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained("C:\\Users\\d6582\\Documents\\ml\\diffusers\\scripts\\warp-diffusion\\WuerstchenGeneratorPipeline", torch_dtype=dtype) # generator_pipeline = generator_pipeline.to("cuda") @@ -108,7 +108,7 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: prior_output = prior_pipeline( caption, height=1024, - width=4096, + width=2048, guidance_scale=8.0, num_images_per_prompt=batch_size, negative_prompt=negative_prompt, @@ -117,14 +117,15 @@ def numpy_to_pil(images: np.ndarray) -> list[Image.Image]: predicted_image_embeddings=prior_output.image_embeds, prompt=caption, negative_prompt=negative_prompt, + num_images_per_prompt=batch_size, guidance_scale=0.0, - output_type="np", + output_type="pil", ).images print("Done Sampling") - images = numpy_to_pil(generator_output) + # images = numpy_to_pil(generator_output) os.makedirs("samples", exist_ok=True) - for i, image in enumerate(images): + for i, image in enumerate(generator_output): image.save(os.path.join("samples", caption.replace(" ", "_").replace("|", "") + f"_{i}.png")) caption = input("Prompt please: ") From 433bdedf511a720fb76ba0a955a806191138411c Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 7 Aug 2023 16:35:24 +0200 Subject: [PATCH 101/181] resolution calculation fix + readme --- docs/source/en/api/pipelines/wuerstchen.md | 55 ++++++++++++++++++- .../wuerstchen/pipeline_wuerstchen.py | 19 +++++-- .../wuerstchen/pipeline_wuerstchen_prior.py | 13 ++++- 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index ba462f79a4a5..1a272ba25529 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -1,12 +1,61 @@ # Würstchen -[Wuerstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville. +[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville. The abstract from the paper is: -*We introduce Wuerstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.* +*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.* -The original codebase can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). +## Würstchen v2 comes to Diffusers! +After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements. +- Higher resolution (1024x1024 up to 2048x2048) +- Faster inference +- Multi Aspect Resolution Sampling +- Better quality + +## Text-to-Image Generation +```python +import torch +from diffusers import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline + +device = "cuda" +dtype = torch.float16 +num_images_per_prompt = 2 + +prior_pipeline = WuerstchenPriorPipeline.from_pretrained( + "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype +).to(device) +generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( + "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype +).to(device) + +caption = "A captivating artwork of a mysterious stone golem" +negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" + +prior_output = prior_pipeline( + prompt=caption, + height=1024, + width=1024, + negative_prompt=negative_prompt, + guidance_scale=8.0, + num_images_per_prompt=num_images_per_prompt, +) +generator_output = generator_pipeline( + predicted_image_embeddings=prior_output.image_embeds, + prompt=caption, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + guidance_scale=0.0, + output_type="pil", +).images + +``` + +## Pipeline Explained +Würstchen consists out of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating images condidtioned on text, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A both happen in the `generator_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). + + +The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). ## WuerstchenPriorPipeline [[autodoc]] WuerstchenGeneratorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index d7930cacfe28..2ce3be0ec2e1 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -254,6 +254,11 @@ def __call__( prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt @@ -263,14 +268,14 @@ def __call__( ) dtype = predicted_image_embeddings.dtype - latent_height = int(predicted_image_embeddings.size(2) * (256 / 24)) - latent_width = int(predicted_image_embeddings.size(3) * (256 / 24)) - effnet_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) + latent_height = int(predicted_image_embeddings.size(2) * 10.67) + latent_width = int(predicted_image_embeddings.size(3) * 10.67) + latent_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents) + latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents) for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) @@ -297,18 +302,22 @@ def __call__( generator=generator, ).prev_sample + print("1") images = self.vqgan.decode(latents).sample.clamp(0, 1) + print("2") if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") if output_type == "np": + print("3") images = images.permute(0, 2, 3, 1).cpu().numpy() + print("4") elif output_type == "pil": images = images.permute(0, 2, 3, 1).cpu().numpy() images = self.numpy_to_pil(images) if not return_dict: return images - + print("5") return ImagePipelineOutput(images) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 1da36701d691..5c0e7c5a546c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -17,6 +17,7 @@ import numpy as np import torch +from math import ceil from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler @@ -86,7 +87,6 @@ def __init__( scheduler: DDPMWuerstchenScheduler, ) -> None: super().__init__() - self.multiple = 128 self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, @@ -227,6 +227,11 @@ def __call__( prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") batch_size = len(prompt) if isinstance(prompt, list) else 1 text_encoder_hidden_states = self._encode_prompt( @@ -234,8 +239,10 @@ def __call__( ) dtype = text_encoder_hidden_states.dtype - latent_height = int(128 * (height / 128) / (1024 / 24)) - latent_width = int(128 * (width / 128) / (1024 / 24)) + # latent_height = int(self.multiple * (height / self.multiple) / (1024 / 24)) + latent_height = ceil(height / 42.67) + # latent_width = int(self.multiple * (width / self.multiple) / (1024 / 24)) + latent_width = ceil(width / 42.67) num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) From 05b58bc52ca1b8c163ca0bcd11d38e3ed62f32ba Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 7 Aug 2023 16:55:04 +0200 Subject: [PATCH 102/181] small fixes --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 14 +++++++------- .../wuerstchen/pipeline_wuerstchen_prior.py | 3 +++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 2ce3be0ec2e1..f7d71eb19fbf 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -239,7 +239,7 @@ def __call__( prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: dict[float, int] = {0.0: 12}, - guidance_scale: float = 3.0, + guidance_scale: float = 0.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -247,9 +247,11 @@ def __call__( return_dict: bool = True, ): device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(num_inference_steps, int): + num_inference_steps = {0.0: num_inference_steps} + if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): @@ -302,22 +304,20 @@ def __call__( generator=generator, ).prev_sample - print("1") images = self.vqgan.decode(latents).sample.clamp(0, 1) - print("2") if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") if output_type == "np": - print("3") images = images.permute(0, 2, 3, 1).cpu().numpy() - print("4") elif output_type == "pil": + print(1) + images.permute(0, 2, 3, 1) + print(2) images = images.permute(0, 2, 3, 1).cpu().numpy() images = self.numpy_to_pil(images) if not return_dict: return images - print("5") return ImagePipelineOutput(images) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 5c0e7c5a546c..1cb8d77363e9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -223,6 +223,9 @@ def __call__( device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(num_inference_steps, int): + num_inference_steps = {0.0: num_inference_steps} + if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): From b0dc35cc28dff2ebefeed31edc544ea3879884ba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 7 Aug 2023 16:01:09 +0000 Subject: [PATCH 103/181] Add combined pipeline --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 3 + .../pipelines/wuerstchen/__init__.py | 1 + .../wuerstchen/pipeline_wuerstchen.py | 11 +- .../pipeline_wuerstchen_combined.py | 239 ++++++++++++++++++ .../wuerstchen/pipeline_wuerstchen_prior.py | 5 - 7 files changed, 247 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6ed7a9f5d16e..926b7cdf12f3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -210,6 +210,7 @@ VQDiffusionPipeline, WuerstchenGeneratorPipeline, WuerstchenPriorPipeline, + WuerstchenPipeline, ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5b68d3e07404..124f64c1fcd6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -125,7 +125,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline + from .wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline, WuerstchenPipeline try: diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 314a1917af7b..b424246a37f2 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -50,6 +50,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from .wuerstchen import WuerstchenPipeline, WuerstchenGeneratorPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -61,6 +62,7 @@ ("kandinsky22", KandinskyV22CombinedPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), + ("wuerstchen", WuerstchenPipeline), ] ) @@ -90,6 +92,7 @@ [ ("kandinsky", KandinskyPipeline), ("kandinsky22", KandinskyV22Pipeline), + ("wuerstchen", WuerstchenGeneratorPipeline), ] ) _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 877db65d8c9b..3eb48aba9383 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -4,4 +4,5 @@ if is_transformers_available() and is_torch_available(): from .modules import DiffNeXt, EfficientNetEncoder, Prior from .pipeline_wuerstchen import WuerstchenGeneratorPipeline + from .pipeline_wuerstchen_combined import WuerstchenPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index f7d71eb19fbf..907ec36ae09c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -235,7 +235,7 @@ def encode_image(self, image): @torch.no_grad() def __call__( self, - predicted_image_embeddings: torch.Tensor, + image_embeds: torch.Tensor, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: dict[float, int] = {0.0: 12}, @@ -257,16 +257,12 @@ def __call__( elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - elif not isinstance(negative_prompt, list): - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs( - predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device + image_embeds, text_encoder_hidden_states, do_classifier_free_guidance, device ) dtype = predicted_image_embeddings.dtype @@ -312,9 +308,6 @@ def __call__( if output_type == "np": images = images.permute(0, 2, 3, 1).cpu().numpy() elif output_type == "pil": - print(1) - images.permute(0, 2, 3, 1) - print(2) images = images.permute(0, 2, 3, 1).cpu().numpy() images = self.numpy_to_pil(images) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py new file mode 100644 index 000000000000..1f2067d45449 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -0,0 +1,239 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import PIL +import torch +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from .modules import Prior + +from ...models import VQModelPaella +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import is_accelerate_available, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modules import DiffNeXt, EfficientNetEncoder + + + +from ...models import PriorTransformer, UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler +from ...utils import ( + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from .pipeline_wuerstchen import WuerstchenGeneratorPipeline +from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + ``` +""" + +class WuerstchenPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Wuerstchen + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + TODO + """ + _load_connected_pipes = True + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + generator: DiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: VQModelPaella, + efficient_net: EfficientNetEncoder, + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModel, + prior_prior: Prior, + prior_scheduler: DDPMWuerstchenScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + generator=generator, + scheduler=scheduler, + vqgan=vqgan, + efficient_net=efficient_net, + prior_prior=prior_prior, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + ) + self.prior_pipe = WuerstchenPriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + ) + self.decoder_pipe = WuerstchenGeneratorPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + generator=generator, + scheduler=scheduler, + vqgan=vqgan, + efficient_net=efficient_net, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload() + self.decoder_pipe.enable_model_cpu_offload() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + outputs = self.decoder_pipe( + prompt=prompt, + image_embeds=image_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + return_dict=return_dict, + ) + return outputs diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 1cb8d77363e9..68801dca3ef9 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -230,11 +230,6 @@ def __call__( prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - elif not isinstance(negative_prompt, list): - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") batch_size = len(prompt) if isinstance(prompt, list) else 1 text_encoder_hidden_states = self._encode_prompt( From 3e085301ba63fc03e80162c424cc08c2a7780396 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 7 Aug 2023 16:27:22 +0000 Subject: [PATCH 104/181] rename generator -> decoder --- docs/source/en/api/pipelines/wuerstchen.md | 12 ++++++------ scripts/convert_wuerstchen.py | 6 +++--- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 4 ++-- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 8 ++++---- .../wuerstchen/pipeline_wuerstchen_combined.py | 4 ++-- .../utils/dummy_torch_and_transformers_objects.py | 2 +- 9 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 1a272ba25529..6679892160a9 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -16,7 +16,7 @@ After the initial paper release, we have improved numerous things in the archite ## Text-to-Image Generation ```python import torch -from diffusers import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline +from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline device = "cuda" dtype = torch.float16 @@ -25,8 +25,8 @@ num_images_per_prompt = 2 prior_pipeline = WuerstchenPriorPipeline.from_pretrained( "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype ).to(device) -generator_pipeline = WuerstchenGeneratorPipeline.from_pretrained( - "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=dtype +generator_pipeline = WuerstchenDecoderPipeline.from_pretrained( + "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype ).to(device) caption = "A captivating artwork of a mysterious stone golem" @@ -58,14 +58,14 @@ Würstchen consists out of 3 stages: Stage C, Stage B, Stage A. They all have di The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). ## WuerstchenPriorPipeline -[[autodoc]] WuerstchenGeneratorPipeline +[[autodoc]] WuerstchenDecoderPipeline - all - __call__ ## WuerstchenPriorPipelineOutput [[autodoc]] WuerstchenPriorPipelineOutput -## WuerstchenGeneratorPipeline -[[autodoc]] WuerstchenGeneratorPipeline +## WuerstchenDecoderPipeline +[[autodoc]] WuerstchenDecoderPipeline - all - __call__ \ No newline at end of file diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 78ae402ea6b5..1e8da8d2c1e8 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -7,7 +7,7 @@ from diffusers import ( DDPMWuerstchenScheduler, VQModelPaella, - WuerstchenGeneratorPipeline, + WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) from diffusers.pipelines.wuerstchen import DiffNeXt, EfficientNetEncoder, Prior @@ -62,7 +62,7 @@ prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") -generator_pipeline = WuerstchenGeneratorPipeline( +generator_pipeline = WuerstchenDecoderPipeline( text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, @@ -70,4 +70,4 @@ efficient_net=efficient_net, scheduler=scheduler, ) -generator_pipeline.save_pretrained("warp-diffusion/WuerstchenGeneratorPipeline") +generator_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 926b7cdf12f3..74ffdac2c764 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -208,7 +208,7 @@ VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, VQDiffusionPipeline, - WuerstchenGeneratorPipeline, + WuerstchenDecoderPipeline, WuerstchenPriorPipeline, WuerstchenPipeline, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 124f64c1fcd6..7fccbb18b280 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -125,7 +125,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenGeneratorPipeline, WuerstchenPriorPipeline, WuerstchenPipeline + from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPriorPipeline, WuerstchenPipeline try: diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index b424246a37f2..d1b7183dc51f 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -50,7 +50,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) -from .wuerstchen import WuerstchenPipeline, WuerstchenGeneratorPipeline +from .wuerstchen import WuerstchenPipeline, WuerstchenDecoderPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -92,7 +92,7 @@ [ ("kandinsky", KandinskyPipeline), ("kandinsky22", KandinskyV22Pipeline), - ("wuerstchen", WuerstchenGeneratorPipeline), + ("wuerstchen", WuerstchenDecoderPipeline), ] ) _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 3eb48aba9383..deb300f1c530 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -3,6 +3,6 @@ if is_transformers_available() and is_torch_available(): from .modules import DiffNeXt, EfficientNetEncoder, Prior - from .pipeline_wuerstchen import WuerstchenGeneratorPipeline + from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 907ec36ae09c..efa8f1c1f1e0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -31,13 +31,13 @@ Examples: ```py >>> import torch - >>> from diffusers import WuerstchenPriorPipeline, WuerstchenGeneratorPipeline + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( ... "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=torch.float16 ... ).to("cuda") - >>> gen_pipe = WuerstchenGeneratorPipeline.from_pretrain( - ... "warp-diffusion/WuerstchenGeneratorPipeline", torch_dtype=torch.float16 + >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain( + ... "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" @@ -47,7 +47,7 @@ """ -class WuerstchenGeneratorPipeline(DiffusionPipeline): +class WuerstchenDecoderPipeline(DiffusionPipeline): """ Pipeline for generating images from the Wuerstchen model. diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 1f2067d45449..c90852ec009e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -35,7 +35,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from .pipeline_wuerstchen import WuerstchenGeneratorPipeline +from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline TEXT2IMAGE_EXAMPLE_DOC_STRING = """ @@ -89,7 +89,7 @@ def __init__( tokenizer=prior_tokenizer, scheduler=prior_scheduler, ) - self.decoder_pipe = WuerstchenGeneratorPipeline( + self.decoder_pipe = WuerstchenDecoderPipeline( text_encoder=text_encoder, tokenizer=tokenizer, generator=generator, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6d7fb04777f5..412011aeea0b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1127,7 +1127,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenGeneratorPipeline(metaclass=DummyObject): +class WuerstchenDecoderPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 711246a7ce487566ed874bc970b924ce019694b5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 7 Aug 2023 21:43:05 +0200 Subject: [PATCH 105/181] Update .gitignore Co-authored-by: Patrick von Platen --- .gitignore | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 2b81433477dd..45602a1f547e 100644 --- a/.gitignore +++ b/.gitignore @@ -173,10 +173,4 @@ tags # ruff .ruff_cache -wandb -scripts/models/ -scripts/warp-diffusion/ -scripts/warp-diffusion-test/ -scripts/samples/ -scripts/text_encoding_colab.pt -scripts/text_encoding.pt \ No newline at end of file +wandb \ No newline at end of file From 5ca1fe0e32c9507e4f31b43fb224c0fd3441485e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 8 Aug 2023 11:13:56 +0200 Subject: [PATCH 106/181] removed efficient_net --- scripts/convert_wuerstchen.py | 16 +++++++-------- .../pipelines/wuerstchen/__init__.py | 2 +- src/diffusers/pipelines/wuerstchen/modules.py | 20 ------------------- .../wuerstchen/pipeline_wuerstchen.py | 15 +++----------- .../pipeline_wuerstchen_combined.py | 19 ++++++------------ .../dummy_torch_and_transformers_objects.py | 15 ++++++++++++++ 6 files changed, 33 insertions(+), 54 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 1e8da8d2c1e8..78d0487f460a 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -10,7 +10,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) -from diffusers.pipelines.wuerstchen import DiffNeXt, EfficientNetEncoder, Prior +from diffusers.pipelines.wuerstchen import DiffNeXt, Prior model_path = "models/" @@ -33,17 +33,18 @@ text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") -# EfficientNet -state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) -efficient_net = EfficientNetEncoder() -efficient_net.load_state_dict(state_dict["effnet_state_dict"]) # Generator +state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") generator = DiffNeXt() generator.load_state_dict(state_dict["state_dict"]) +# EfficientNet +# efficient_net = EfficientNetEncoder() +# efficient_net.load_state_dict(state_dict["effnet_state_dict"]) + # Prior state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) @@ -62,12 +63,11 @@ prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") -generator_pipeline = WuerstchenDecoderPipeline( +decoder_pipeline = WuerstchenDecoderPipeline( text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, generator=generator, - efficient_net=efficient_net, scheduler=scheduler, ) -generator_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") +decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index deb300f1c530..c99f5a17a509 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,7 +2,7 @@ if is_transformers_available() and is_torch_available(): - from .modules import DiffNeXt, EfficientNetEncoder, Prior + from .modules import DiffNeXt, Prior from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 8b32e6b886a6..75749058e2b2 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -112,26 +112,6 @@ def forward(self, x, kv=None): return x -class EfficientNetEncoder(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, c_latent=16, effnet="efficientnet_v2_s"): - super().__init__() - from torchvision.models import efficientnet_v2_l, efficientnet_v2_s # can't use `torchvision` - - if effnet == "efficientnet_v2_s": - self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() - else: - print("Using EffNet L.") - self.backbone = efficientnet_v2_l(weights="DEFAULT").features.eval() - self.mapper = nn.Sequential( - nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - ) - - def forward(self, x): - return self.mapper(self.backbone(x)) - - class DiffNeXt(ModelMixin, ConfigMixin): @register_to_config def __init__( diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index efa8f1c1f1e0..1412a394e8cd 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -22,7 +22,7 @@ from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modules import DiffNeXt, EfficientNetEncoder +from .modules import DiffNeXt logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -63,8 +63,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): The DiffNeXt unet generator. vqgan ([`VQModelPaella`]): The VQGAN model. - efficient_net ([`EfficientNetEncoder`]): - The EfficientNet encoder. scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. """ @@ -76,7 +74,6 @@ def __init__( generator: DiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, - efficient_net: EfficientNetEncoder, ) -> None: super().__init__() self.register_modules( @@ -85,7 +82,6 @@ def __init__( generator=generator, scheduler=scheduler, vqgan=vqgan, - efficient_net=efficient_net, ) self.register_to_config() @@ -102,7 +98,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents): def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, - vqgan and efficient_net have their state dicts saved to CPU and then are moved to a `torch.device('meta') and + and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): @@ -112,7 +108,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder, self.vqgan, self.efficient_net]: + for cpu_offloaded_model in [self.text_encoder, self.vqgan]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @@ -228,10 +224,6 @@ def check_inputs( return predicted_image_embeddings, text_encoder_hidden_states - @torch.no_grad() - def encode_image(self, image): - return self.efficient_net(image) - @torch.no_grad() def __call__( self, @@ -256,7 +248,6 @@ def __call__( prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index c90852ec009e..e0d2467da354 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -19,24 +19,18 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from .modules import Prior - from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modules import DiffNeXt, EfficientNetEncoder - - - from ...models import PriorTransformer, UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler -from ...utils import ( - replace_example_docstring, -) +from ...utils import replace_example_docstring + +from .modules import DiffNeXt, Prior from ..pipeline_utils import DiffusionPipeline from .pipeline_wuerstchen import WuerstchenDecoderPipeline -from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline +from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline TEXT2IMAGE_EXAMPLE_DOC_STRING = """ Examples: @@ -44,6 +38,7 @@ ``` """ + class WuerstchenPipeline(DiffusionPipeline): """ Combined Pipeline for text-to-image generation using Wuerstchen @@ -54,6 +49,7 @@ class WuerstchenPipeline(DiffusionPipeline): Args: TODO """ + _load_connected_pipes = True def __init__( @@ -63,7 +59,6 @@ def __init__( generator: DiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, - efficient_net: EfficientNetEncoder, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, prior_prior: Prior, @@ -77,7 +72,6 @@ def __init__( generator=generator, scheduler=scheduler, vqgan=vqgan, - efficient_net=efficient_net, prior_prior=prior_prior, prior_text_encoder=prior_text_encoder, prior_tokenizer=prior_tokenizer, @@ -95,7 +89,6 @@ def __init__( generator=generator, scheduler=scheduler, vqgan=vqgan, - efficient_net=efficient_net, ) def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 412011aeea0b..8c360571ab84 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1155,3 +1155,18 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) From 1f4bb0abd4a577a7e42c025c9276381c396a518e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 8 Aug 2023 12:00:06 +0200 Subject: [PATCH 107/181] create combined WuerstchenPipeline --- scripts/convert_wuerstchen.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 78d0487f460a..d20c2fb132f4 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -9,6 +9,7 @@ VQModelPaella, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, + WuerstchenPipeline, ) from diffusers.pipelines.wuerstchen import DiffNeXt, Prior @@ -71,3 +72,20 @@ scheduler=scheduler, ) decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") + + +# Wuerstchen pipeline +wuerstchen_pipeline = WuerstchenPipeline( + # Decoder + text_encoder=gen_text_encoder, + tokenizer=gen_tokenizer, + generator=generator, + scheduler=scheduler, + vqgan=vqmodel, + # Prior + prior_tokenizer=tokenizer, + prior_text_encoder=text_encoder, + prior_prior=prior_model, + prior_scheduler=scheduler, +) +wuerstchen_pipeline.save_pretrained("warp-diffusion/Wuerstchen") From 474ec70bef0e7b6e7513d95c9d6698348a3115c1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 11 Aug 2023 15:08:49 +0200 Subject: [PATCH 108/181] make arguments consistent with VQ model --- scripts/convert_wuerstchen.py | 4 ++-- src/diffusers/models/vq_paella.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index d20c2fb132f4..889df9168752 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -24,8 +24,8 @@ state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] state_dict.pop("vquantizer.codebook.weight") vqmodel = VQModelPaella( - codebook_size=paella_vqmodel.codebook_size, - c_latent=paella_vqmodel.c_latent, + num_vq_embeddings=paella_vqmodel.codebook_size, + latent_channels=paella_vqmodel.c_latent, ) vqmodel.load_state_dict(state_dict) # TODO: test vqmodel outputs match paella_vqmodel outputs diff --git a/src/diffusers/models/vq_paella.py b/src/diffusers/models/vq_paella.py index 0dba4634fb4a..4f35928e17fe 100644 --- a/src/diffusers/models/vq_paella.py +++ b/src/diffusers/models/vq_paella.py @@ -35,9 +35,9 @@ class VQModelPaella(ModelMixin, ConfigMixin): up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. levels (int, *optional*, defaults to 2): Number of levels in the model. bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. - c_hidden (int, *optional*, defaults to 384): Number of hidden channels in the model. - c_latent (int, *optional*, defaults to 4): Number of latent channels in the model. - codebook_size (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model. + latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model. + num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. """ @@ -49,14 +49,14 @@ def __init__( up_down_scale_factor: int = 2, levels: int = 2, bottleneck_blocks: int = 12, - c_hidden: int = 384, - c_latent: int = 4, - codebook_size: int = 8192, + embed_dim: int = 384, + latent_channels: int = 4, + num_vq_embeddings: int = 8192, scale_factor: float = 0.3764, ): super().__init__() - c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] + c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] self.in_block = nn.Sequential( nn.PixelUnshuffle(up_down_scale_factor), nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), @@ -70,15 +70,15 @@ def __init__( down_blocks.append(block) down_blocks.append( nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1 ) ) self.down_blocks = nn.Sequential(*down_blocks) - self.vquantizer = VectorQuantizer(codebook_size, vq_embed_dim=c_latent, legacy=False, beta=0.25) + self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) # Decoder blocks - up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] + up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))] for i in range(levels): for j in range(bottleneck_blocks if i == 0 else 1): block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) From 54d339781886e0abfcd2758e810f6e6139b838c6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Aug 2023 17:06:17 +0200 Subject: [PATCH 109/181] fix var names --- src/diffusers/pipelines/wuerstchen/modules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/modules.py index 75749058e2b2..75b1314a90a3 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/modules.py @@ -92,9 +92,9 @@ def __init__(self, dim): self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x class AttnBlock(nn.Module): From 98f4b547375bd3b8a63116b8a8e2dd35e96c05b2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Aug 2023 17:15:19 +0200 Subject: [PATCH 110/181] no need to return text_encoder_hidden_states --- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 68801dca3ef9..c1aff139048a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -52,12 +52,10 @@ class WuerstchenPriorPipelineOutput(BaseOutput): Args: image_embeds (`torch.FloatTensor` or `np.ndarray`) Prior image embeddings for text prompt - text_embeds (`torch.FloatTensor` or `np.ndarray`) - Clip text embeddings for unconditional tokens + """ image_embeds: Union[torch.FloatTensor, np.ndarray] - text_embeds: Union[torch.FloatTensor, np.ndarray] class WuerstchenPriorPipeline(DiffusionPipeline): @@ -275,9 +273,8 @@ def __call__( if output_type == "np": latents = latents.cpu().numpy() - text_encoder_hidden_states = text_encoder_hidden_states.cpu().numpy() if not return_dict: - return (latents, text_encoder_hidden_states) + return (latents,) - return WuerstchenPriorPipelineOutput(latents, text_encoder_hidden_states) + return WuerstchenPriorPipelineOutput(latents) From b43b4632ca6567ffffa181f90bdc1ab41e043535 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Aug 2023 17:15:48 +0200 Subject: [PATCH 111/181] add latent_dim_scale to config --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 1412a394e8cd..fb1fc02d0a4e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -74,6 +74,7 @@ def __init__( generator: DiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, + latent_dim_scale: float = 10.67, ) -> None: super().__init__() self.register_modules( @@ -257,8 +258,8 @@ def __call__( ) dtype = predicted_image_embeddings.dtype - latent_height = int(predicted_image_embeddings.size(2) * 10.67) - latent_width = int(predicted_image_embeddings.size(3) * 10.67) + latent_height = int(predicted_image_embeddings.size(2) * self.config.latent_dim_scale) + latent_width = int(predicted_image_embeddings.size(3) * self.config.latent_dim_scale) latent_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) self.scheduler.set_timesteps(num_inference_steps, device=device) From bb8c5b1b9ee93c2ce0a532f9a8ae2f405f76e422 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Aug 2023 17:29:54 +0200 Subject: [PATCH 112/181] split model into its own file --- scripts/convert_wuerstchen.py | 2 +- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/wuerstchen/__init__.py | 3 +- src/diffusers/pipelines/wuerstchen/common.py | 105 ++++++++++++ .../wuerstchen/{modules.py => diffnext.py} | 157 +----------------- .../wuerstchen/pipeline_wuerstchen.py | 2 +- .../pipeline_wuerstchen_combined.py | 20 +-- .../wuerstchen/pipeline_wuerstchen_prior.py | 6 +- .../pipelines/wuerstchen/wuerstchen_prior.py | 58 +++++++ .../schedulers/scheduling_ddpm_wuerstchen.py | 2 +- 11 files changed, 185 insertions(+), 174 deletions(-) create mode 100644 src/diffusers/pipelines/wuerstchen/common.py rename src/diffusers/pipelines/wuerstchen/{modules.py => diffnext.py} (60%) create mode 100644 src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 889df9168752..2faef15c9810 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -8,8 +8,8 @@ DDPMWuerstchenScheduler, VQModelPaella, WuerstchenDecoderPipeline, - WuerstchenPriorPipeline, WuerstchenPipeline, + WuerstchenPriorPipeline, ) from diffusers.pipelines.wuerstchen import DiffNeXt, Prior diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 74ffdac2c764..1960fdb3c2dc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -209,8 +209,8 @@ VideoToVideoSDPipeline, VQDiffusionPipeline, WuerstchenDecoderPipeline, - WuerstchenPriorPipeline, WuerstchenPipeline, + WuerstchenPriorPipeline, ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7fccbb18b280..e24308755f99 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -125,7 +125,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPriorPipeline, WuerstchenPipeline + from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline, WuerstchenPriorPipeline try: diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index c99f5a17a509..2403ca6db8f9 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,7 +2,8 @@ if is_transformers_available() and is_torch_available(): - from .modules import DiffNeXt, Prior + from .diffnext import DiffNeXt from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + from .wuerstchen_prior import WuerstchenPrior diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/common.py new file mode 100644 index 000000000000..6676902af122 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/common.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv=None, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn and kv is not None: + kv = torch.cat([x, kv], dim=1) + elif kv is None: + kv = x + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv=None): + if kv is not None: + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x diff --git a/src/diffusers/pipelines/wuerstchen/modules.py b/src/diffusers/pipelines/wuerstchen/diffnext.py similarity index 60% rename from src/diffusers/pipelines/wuerstchen/modules.py rename to src/diffusers/pipelines/wuerstchen/diffnext.py index 75b1314a90a3..3eeb6e6662d4 100644 --- a/src/diffusers/pipelines/wuerstchen/modules.py +++ b/src/diffusers/pipelines/wuerstchen/diffnext.py @@ -4,112 +4,9 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config - +from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin - - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - -class TimestepBlock(nn.Module): - def __init__(self, c, c_timestep): - super().__init__() - self.mapper = nn.Linear(c_timestep, c * 2) - - def forward(self, x, t): - a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) - return x * (1 + a) + b - - -class Attention2D(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv=None, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn and kv is not None: - kv = torch.cat([x, kv], dim=1) - elif kv is None: - kv = x - x = self.attn(x, kv, kv, need_weights=False)[0] - x = x.permute(0, 2, 1).view(*orig_shape) - return x - - -class ResBlockStageB(nn.Module): - def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): - super().__init__() - self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c + c_skip, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - nn.Linear(c * 4, c), - ) - - def forward(self, x, x_skip=None): - x_res = x - x = self.norm(self.depthwise(x)) - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + x_res - - -class ResBlock(nn.Module): - def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): - super().__init__() - self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) - ) - - def forward(self, x, x_skip=None): - x_res = x - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) - x = self.channelwise(x).permute(0, 3, 1, 2) - return x + x_res - - -# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 -class GlobalResponseNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * stand_div_norm) + self.beta + x - - -class AttnBlock(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) - - def forward(self, x, kv=None): - if kv is not None: - kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x +from .common import AttnBlock, LayerNorm2d, ResBlockStageB, TimestepBlock class DiffNeXt(ModelMixin, ConfigMixin): @@ -317,53 +214,3 @@ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=Tr return (x_in - a) / b else: return a, b - - -class Prior(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): - super().__init__() - self.c_r = c_r - self.projection = nn.Conv2d(c_in, c, kernel_size=1) - self.cond_mapper = nn.Sequential( - nn.Linear(c_cond, c), - nn.LeakyReLU(0.2), - nn.Linear(c, c), - ) - - self.blocks = nn.ModuleList() - for _ in range(depth): - self.blocks.append(ResBlock(c, dropout=dropout)) - self.blocks.append(TimestepBlock(c, c_r)) - self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) - self.out = nn.Sequential( - LayerNorm2d(c, elementwise_affine=False, eps=1e-6), - nn.Conv2d(c, c_in * 2, kernel_size=1), - ) - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode="constant") - return emb.to(dtype=r.dtype) - - def forward(self, x, r, c): - x_in = x - x = self.projection(x) - c_embed = self.cond_mapper(c) - r_embed = self.gen_r_embedding(r) - for block in self.blocks: - if isinstance(block, AttnBlock): - x = block(x, c_embed) - elif isinstance(block, TimestepBlock): - x = block(x, r_embed) - else: - x = block(x) - a, b = self.out(x).chunk(2, dim=1) - # denoised = a / (1-(1-b).pow(2)).sqrt() - return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index fb1fc02d0a4e..f35a22a9320d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -22,7 +22,7 @@ from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modules import DiffNeXt +from .diffnext import DiffNeXt logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index e0d2467da354..ce36ae633126 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -13,28 +13,28 @@ # limitations under the License. from typing import Callable, List, Optional, Union -import PIL import torch from transformers import CLIPTextModel, CLIPTokenizer -from transformers import CLIPTextModel, CLIPTokenizer -from ...schedulers import DDPMWuerstchenScheduler from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler -from ...utils import is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...models import PriorTransformer, UNet2DConditionModel, VQModel -from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler from ...utils import replace_example_docstring - -from .modules import DiffNeXt, Prior from ..pipeline_utils import DiffusionPipeline +from .diffnext import DiffNeXt from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline +from .wuerstchen_prior import WuerstchenPrior + TEXT2IMAGE_EXAMPLE_DOC_STRING = """ Examples: ```py + >>> from diffusions import WuerstchenPipeline + + >>> pipe = WuerstchenPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16 + ... ).to("cuda") + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) ``` """ @@ -61,7 +61,7 @@ def __init__( vqgan: VQModelPaella, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, - prior_prior: Prior, + prior_prior: WuerstchenPrior, prior_scheduler: DDPMWuerstchenScheduler, ): super().__init__() diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index c1aff139048a..9e8556cbac2b 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -13,17 +13,17 @@ # limitations under the License. from dataclasses import dataclass +from math import ceil from typing import List, Optional, Union import numpy as np import torch -from math import ceil from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler from ...utils import BaseOutput, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from .modules import Prior +from .wuerstchen_prior import WuerstchenPrior logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -81,7 +81,7 @@ def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - prior: Prior, + prior: WuerstchenPrior, scheduler: DDPMWuerstchenScheduler, ) -> None: super().__init__() diff --git a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py new file mode 100644 index 000000000000..a68348faa5e1 --- /dev/null +++ b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py @@ -0,0 +1,58 @@ +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from .common import AttnBlock, LayerNorm2d, ResBlock, TimestepBlock + + +class WuerstchenPrior(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): + super().__init__() + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + LayerNorm2d(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + # denoised = a / (1-(1-b).pow(2)).sqrt() + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 9821be44ecfe..43b28ef51d4d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch From 2ee6ed9566078228cbdaf7a0727921d627c92a68 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Aug 2023 18:01:34 +0200 Subject: [PATCH 113/181] add WuerschenPipeline to docs --- docs/source/en/api/pipelines/wuerstchen.md | 53 ++++++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 6679892160a9..12dc2b6d5501 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -6,14 +6,19 @@ The abstract from the paper is: *We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.* -## Würstchen v2 comes to Diffusers! +## Würstchen v2 comes to Diffusers + After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements. + - Higher resolution (1024x1024 up to 2048x2048) - Faster inference - Multi Aspect Resolution Sampling - Better quality ## Text-to-Image Generation + +For the sake of explanation, since the model consists of different stages we will perform generation manually as: + ```python import torch from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline @@ -23,7 +28,7 @@ dtype = torch.float16 num_images_per_prompt = 2 prior_pipeline = WuerstchenPriorPipeline.from_pretrained( - "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype + "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype ).to(device) generator_pipeline = WuerstchenDecoderPipeline.from_pretrained( "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype @@ -52,20 +57,60 @@ generator_output = generator_pipeline( ``` ## Pipeline Explained -Würstchen consists out of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating images condidtioned on text, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A both happen in the `generator_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). +Würstchen consists out of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating images conditioned on text, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A both happen in the `generator_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). + +## Combined Pipeline + +For the sake of usability we have combined the two pipelines into one. This pipeline is called `WuerstchenPipeline` and can be used as follows: + +```python +import torch +from diffusers import WuerstchenPipeline + +device = "cuda" +dtype = torch.float16 +num_images_per_prompt = 2 + +pipeline = WuerstchenPipeline.from_pretrained( + "warp-diffusion/WuerstchenPipeline", torch_dtype=dtype +).to(device) + +caption = "A captivating artwork of a mysterious stone golem" +negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" + +output = pipeline( + prompt=caption, + height=1024, + width=1024, + negative_prompt=negative_prompt, + guidance_scale=8.0, + num_images_per_prompt=num_images_per_prompt, + output_type="pil", +).images +``` The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). +## WuerschenPipeline + +[[autodoc]] WuerstchenPipeline + - all + - __call__ + ## WuerstchenPriorPipeline + [[autodoc]] WuerstchenDecoderPipeline + - all - __call__ ## WuerstchenPriorPipelineOutput + [[autodoc]] WuerstchenPriorPipelineOutput ## WuerstchenDecoderPipeline + [[autodoc]] WuerstchenDecoderPipeline - all - - __call__ \ No newline at end of file + - __call__ From d944bb126e53d57be6115ac2881eed76716e453a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Aug 2023 09:48:02 +0200 Subject: [PATCH 114/181] remove unused latent_size --- src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py index a68348faa5e1..462851d08ff0 100644 --- a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py @@ -10,7 +10,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin): @register_to_config - def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, latent_size=(12, 12), dropout=0.1): + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() self.c_r = c_r self.projection = nn.Conv2d(c_in, c, kernel_size=1) From 59eb765d52c136917589a26a92575a262f7a15e2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Aug 2023 10:02:52 +0200 Subject: [PATCH 115/181] register latent_dim_scale --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index f35a22a9320d..1ea09c77c79f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -84,7 +84,7 @@ def __init__( scheduler=scheduler, vqgan=vqgan, ) - self.register_to_config() + self.register_to_config(latent_dim_scale=latent_dim_scale) def prepare_latents(self, shape, dtype, device, generator, latents): if latents is None: From b5ff6815f0fe06344656bac98131014dd0032075 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Aug 2023 10:05:31 +0200 Subject: [PATCH 116/181] update script --- scripts/convert_wuerstchen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 2faef15c9810..67b406efde06 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -11,7 +11,7 @@ WuerstchenPipeline, WuerstchenPriorPipeline, ) -from diffusers.pipelines.wuerstchen import DiffNeXt, Prior +from diffusers.pipelines.wuerstchen import DiffNeXt, WuerstchenPrior model_path = "models/" @@ -48,7 +48,7 @@ # Prior state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) -prior_model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) +prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) prior_model.load_state_dict(state_dict["ema_state_dict"]) # scheduler @@ -88,4 +88,4 @@ prior_prior=prior_model, prior_scheduler=scheduler, ) -wuerstchen_pipeline.save_pretrained("warp-diffusion/Wuerstchen") +wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenPipeline") From 7d6f2e0f75bebcfdb8bb4ff0470d81584c898907 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Aug 2023 10:07:34 +0200 Subject: [PATCH 117/181] update docstring --- scripts/convert_wuerstchen.py | 7 +------ src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 ++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 67b406efde06..66d2e6a291ba 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -1,3 +1,4 @@ +# Run inside root directory of official source code import os import torch @@ -28,13 +29,11 @@ latent_channels=paella_vqmodel.c_latent, ) vqmodel.load_state_dict(state_dict) -# TODO: test vqmodel outputs match paella_vqmodel outputs # Clip Text encoder and tokenizer text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - # Generator state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") @@ -42,10 +41,6 @@ generator = DiffNeXt() generator.load_state_dict(state_dict["state_dict"]) -# EfficientNet -# efficient_net = EfficientNetEncoder() -# efficient_net.load_state_dict(state_dict["effnet_state_dict"]) - # Prior state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 1ea09c77c79f..da64bbdc07a0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -65,6 +65,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): The VQGAN model. scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + The scale of the latent dimension. This is used to determine the size of the latent space. """ def __init__( From f33dd228521114601c18c3afaead57d2b179b72e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Aug 2023 13:25:24 +0200 Subject: [PATCH 118/181] use Attention preprocessor --- scripts/convert_wuerstchen.py | 66 ++++++++++++++------ src/diffusers/pipelines/wuerstchen/common.py | 26 +++----- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 66d2e6a291ba..3bceb51968d1 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -24,10 +24,7 @@ state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] state_dict.pop("vquantizer.codebook.weight") -vqmodel = VQModelPaella( - num_vq_embeddings=paella_vqmodel.codebook_size, - latent_channels=paella_vqmodel.c_latent, -) +vqmodel = VQModelPaella(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent) vqmodel.load_state_dict(state_dict) # Clip Text encoder and tokenizer @@ -35,40 +32,73 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") # Generator -state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device) gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + +orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"] +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] generator = DiffNeXt() -generator.load_state_dict(state_dict["state_dict"]) +generator.load_state_dict(state_dict) # Prior -state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device) +orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"] +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) -prior_model.load_state_dict(state_dict["ema_state_dict"]) +prior_model.load_state_dict(state_dict) # scheduler scheduler = DDPMWuerstchenScheduler() # Prior pipeline prior_pipeline = WuerstchenPriorPipeline( - prior=prior_model, - text_encoder=text_encoder, - tokenizer=tokenizer, - scheduler=scheduler, + prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler ) prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") decoder_pipeline = WuerstchenDecoderPipeline( - text_encoder=gen_text_encoder, - tokenizer=gen_tokenizer, - vqgan=vqmodel, - generator=generator, - scheduler=scheduler, + text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, generator=generator, scheduler=scheduler ) decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") - # Wuerstchen pipeline wuerstchen_pipeline = WuerstchenPipeline( # Decoder diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/common.py index 6676902af122..a1b26f62d49f 100644 --- a/src/diffusers/pipelines/wuerstchen/common.py +++ b/src/diffusers/pipelines/wuerstchen/common.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from ...models.attention_processor import Attention + class LayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): @@ -20,23 +22,6 @@ def forward(self, x, t): return x * (1 + a) + b -class Attention2D(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv=None, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn and kv is not None: - kv = torch.cat([x, kv], dim=1) - elif kv is None: - kv = x - x = self.attn(x, kv, kv, need_weights=False)[0] - x = x.permute(0, 2, 1).view(*orig_shape) - return x - - class ResBlockStageB(nn.Module): def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): super().__init__() @@ -95,11 +80,14 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() self.self_attn = self_attn self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) def forward(self, x, kv=None): if kv is not None: kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + if self.self_attn and kv is not None: + batch_size, channel, height, width = x.shape + kv = torch.cat([x.view(batch_size, channel, height * width).transpose(1, 2), kv], dim=1) + x = x + self.attention(self.norm(x), encoder_hidden_states=kv) return x From 17d28e3809e6f4898decccbd772d5a6dd9c199ce Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 17 Aug 2023 22:16:39 +0200 Subject: [PATCH 119/181] concat with normed input --- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/pipelines/wuerstchen/common.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 43497c2284ac..e4a534b188a5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -187,7 +187,7 @@ def set_use_memory_efficient_attention_xformers( if use_memory_efficient_attention_xformers: if is_added_kv_processor and (is_lora or is_custom_diffusion): raise NotImplementedError( - f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}" + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" ) if not is_xformers_available(): raise ModuleNotFoundError( diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/common.py index a1b26f62d49f..57fe258a7be1 100644 --- a/src/diffusers/pipelines/wuerstchen/common.py +++ b/src/diffusers/pipelines/wuerstchen/common.py @@ -83,11 +83,11 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) - def forward(self, x, kv=None): - if kv is not None: - kv = self.kv_mapper(kv) - if self.self_attn and kv is not None: - batch_size, channel, height, width = x.shape - kv = torch.cat([x.view(batch_size, channel, height * width).transpose(1, 2), kv], dim=1) - x = x + self.attention(self.norm(x), encoder_hidden_states=kv) + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) return x From 3f49d5250aea59e1b9141b00bcf1aaea7e0c1e75 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 10:57:58 +0200 Subject: [PATCH 120/181] fix-copies --- src/diffusers/pipelines/auto_pipeline.py | 2 +- src/diffusers/utils/dummy_torch_and_transformers_objects.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index d1b7183dc51f..58a73fe9ecc0 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -50,7 +50,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) -from .wuerstchen import WuerstchenPipeline, WuerstchenDecoderPipeline +from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d46744648c02..102cc3219f10 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1157,7 +1157,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenPriorPipeline(metaclass=DummyObject): +class WuerstchenPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1172,7 +1172,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenPipeline(metaclass=DummyObject): +class WuerstchenPriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 5ea91e1d096082b4c1df5acf4bda796d7777197d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 11:08:49 +0200 Subject: [PATCH 121/181] add docs --- .../pipeline_wuerstchen_combined.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index ce36ae633126..6b5ee7e8b683 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -47,7 +47,24 @@ class WuerstchenPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - TODO + tokenizer (:class:`~transformers.CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (:class:`~transformers.CLIPTextModel`): + The decoder text encoder to be used for text inputs. + generator (:class:`~diffusions.models.DiffNeXt`): + The generator model to be used for decoder image generation pipeline. + scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (:class:`~diffusions.models.VQModelPaella`): + The VQGAN model to be used for decoder image generation pipeline. + prior_tokenizer (:class:`~transformers.CLIPTokenizer`): + The prior tokenizer to be used for text inputs. + prior_text_encoder (:class:`~transformers.CLIPTextModel`): + The prior text encoder to be used for text inputs. + prior_prior (:class:`~diffusions.pipelines.wuerstchen.wuerstchen_prior.WuerstchenPrior`): + The prior model to be used for prior pipeline. + prior_scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. """ _load_connected_pipes = True From 752d3f57a03eda5283ecae4fee0d881f5cf081c7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 12:18:58 +0200 Subject: [PATCH 122/181] fix test --- .../wuerstchen/test_wuerstchen_prior.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 07db13634136..c6850365aef9 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -24,7 +24,7 @@ ) from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline -from diffusers.pipelines.wuerstchen import Prior +from diffusers.pipelines.wuerstchen import WuerstchenPrior from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps @@ -85,7 +85,7 @@ def dummy_text_encoder(self): pad_token_id=1, vocab_size=1000, ) - return CLIPTextModel(config) + return CLIPTextModel(config).eval() @property def dummy_prior(self): @@ -98,11 +98,10 @@ def dummy_prior(self): "c_cond": 32, "c_r": 8, "nhead": 2, - "latent_size": (2, 2), } - model = Prior(**model_kwargs) - return model + model = WuerstchenPrior(**model_kwargs) + return model.eval() def get_dummy_components(self): prior = self.dummy_prior @@ -158,9 +157,19 @@ def test_wuerstchen_prior(self): assert image.shape == (1, 2, 24, 24) expected_slice = np.array( - [-0.0532, 1.7120, 0.3656, -1.0852, -0.8946, -1.1756, 0.4348, 0.2482, 0.5146, -0.1156] + [ + -7172.9814, + -3438.9731, + -1093.4564, + 388.91516, + -7471.7383, + -7998.2944, + -5328.388, + 218.0543, + -2731.6716, + -8056.8545, + ], ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 From eadd628680ec5112eb5481fd8748197183382d78 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 12:22:29 +0200 Subject: [PATCH 123/181] fix style --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 4 ++-- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 5 +++-- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 6 +----- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index da64bbdc07a0..cdccf2f6398d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -101,8 +101,8 @@ def prepare_latents(self, shape, dtype, device, generator, latents): def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, - and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and - loaded to GPU only when their specific submodule has its `forward` method called. + and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU + only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 6b5ee7e8b683..aa4666eefc8d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -31,8 +31,9 @@ ```py >>> from diffusions import WuerstchenPipeline - >>> pipe = WuerstchenPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16 - ... ).to("cuda") + >>> pipe = WuerstchenPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> images = pipe(prompt=prompt) ``` diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index c6850365aef9..67756509a896 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -17,11 +17,7 @@ import numpy as np import torch -from transformers import ( - CLIPTextConfig, - CLIPTextModel, - CLIPTokenizer, -) +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline from diffusers.pipelines.wuerstchen import WuerstchenPrior From fda1f68308097b10f44d413a26f7109e7fd99d97 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 12:28:12 +0200 Subject: [PATCH 124/181] add to cpu_offloaded_model --- src/diffusers/pipelines/wuerstchen/common.py | 4 +++- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 6 +++--- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 6 +++--- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 5 +---- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/common.py index 57fe258a7be1..d2650110eb67 100644 --- a/src/diffusers/pipelines/wuerstchen/common.py +++ b/src/diffusers/pipelines/wuerstchen/common.py @@ -9,7 +9,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) class TimestepBlock(nn.Module): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index cdccf2f6398d..29ae1fe3cf6d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -101,8 +101,8 @@ def prepare_latents(self, shape, dtype, device, generator, latents): def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, - and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU - only when their specific submodule has its `forward` method called. + generator, and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and + loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -111,7 +111,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder, self.vqgan]: + for cpu_offloaded_model in [self.text_encoder, self.vqgan, self.generator]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 9e8556cbac2b..f41653e8dc36 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -96,8 +96,8 @@ def __init__( def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the text_encoder - have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when - their specific submodule has its `forward` method called. + and the prior have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to + GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -106,7 +106,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder]: + for cpu_offloaded_model in [self.text_encoder, self.prior]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 67756509a896..254222bdf43c 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -142,10 +142,7 @@ def test_wuerstchen_prior(self): output = pipe(**self.get_dummy_inputs(device)) image = output.image_embeds - image_from_tuple = pipe( - **self.get_dummy_inputs(device), - return_dict=False, - )[0] + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] image_slice = image[0, 0, 0, -10:] image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] From f8ddabdf7a41f6b48c160cd7cf288928c0376568 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 12:32:27 +0200 Subject: [PATCH 125/181] updated type --- src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 43b28ef51d4d..40140942ea32 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch @@ -136,14 +136,14 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = def set_timesteps( self, - num_inference_steps: dict[float, int], + num_inference_steps: Dict[float, int], device: Union[str, torch.device] = None, ): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: - num_inference_steps (`dict[float, int]`): + num_inference_steps (`Dict[float, int]`): the number of diffusion steps used when generating samples with a pre-trained model. If passed, then `timesteps` must be `None`. device (`str` or `torch.device`, optional): From 45dcfe1a1ba73609f2a2ebfcf0f2ffa7fe63f2a8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 12:43:49 +0200 Subject: [PATCH 126/181] remove 1-line func --- src/diffusers/models/resnet.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 0a38ead6ad83..f7d63167783e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -690,18 +690,14 @@ def __init__(self, inp_channels, c_hidden): nn.GELU(), nn.Linear(c_hidden, inp_channels), ) - + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - @staticmethod - def _norm(x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - def forward(self, x): mods = self.gammas - x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] x = x + self.depthwise(x_temp) * mods[2] - x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] return x From cbf8780f784486e442e06a8018ae370c724067fe Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 13:04:32 +0200 Subject: [PATCH 127/181] updated type --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 4 ++-- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 29ae1fe3cf6d..00ca503fa0b2 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -233,7 +233,7 @@ def __call__( image_embeds: torch.Tensor, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: dict[float, int] = {0.0: 12}, + num_inference_steps: Dict[float, int] = {0.0: 12}, guidance_scale: float = 0.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index f41653e8dc36..ad0def8cba26 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from math import ceil -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -209,7 +209,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: dict[float, int] = {2 / 3: 20, 0.0: 10}, + num_inference_steps: Dict[float, int] = {2 / 3: 20, 0.0: 10}, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, From 15b5f424c44a68be07867b6187ef1ca73be1eba3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 13:04:42 +0200 Subject: [PATCH 128/181] initial decoder test --- .../wuerstchen/test_wuerstchen_decoder.py | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 tests/pipelines/wuerstchen/test_wuerstchen_decoder.py diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py new file mode 100644 index 000000000000..2af040058172 --- /dev/null +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline +from diffusers.pipelines.wuerstchen import DiffNeXt +from diffusers.models import VQModelPaella +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import enable_full_determinism, skip_mps + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WuerstchenDecoderPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "num_images_per_prompt", + "generator", + "num_inference_steps", + "latents", + "negative_prompt", + "guidance_scale", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config).eval() + + @property + def dummy_vqgan(self): + torch.manual_seed(0) + + model_kwargs = { + "in_channels": 3, + } + model = VQModelPaella(**model_kwargs) + return model.eval() + + @property + def dummy_generator(self): + torch.manual_seed(0) + + model_kwargs = { + "c_in": 4, + } + + model = DiffNeXt(**model_kwargs) + return model.eval() + + def get_dummy_components(self): + generator = self.dummy_generator + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + vqgan = self.dummy_vqgan + + scheduler = DDPMWuerstchenScheduler() + + components = { + "generator": generator, + "vqgan": vqgan, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image_embeds": torch.ones(1, 3, 32, 32, device=device), + "prompt": "horse", + "generator": generator, + "guidance_scale": 1.0, + "num_inference_steps": {0.0: 2}, + "output_type": "np", + } + return inputs + + def test_wuerstchen_decoder(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.image_embeds + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] + + image_slice = image[0, 0, 0, -10:] + image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] + + assert image.shape == (1, 2, 24, 24) + + expected_slice = np.array( + [ + -7172.9814, + -3438.9731, + -1093.4564, + 388.91516, + -7471.7383, + -7998.2944, + -5328.388, + 218.0543, + -2731.6716, + -8056.8545, + ], + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @skip_mps + def test_inference_batch_single_identical(self): + test_max_difference = torch_device == "cpu" + relax_max_difference = True + test_mean_pixel_difference = False + + self._test_inference_batch_single_identical( + test_max_difference=test_max_difference, + relax_max_difference=relax_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) + + @skip_mps + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + test_mean_pixel_difference = False + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) From bd362dfb43dc99f081ae36d4fef16d214b173ea7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 13:07:12 +0200 Subject: [PATCH 129/181] formatting --- src/diffusers/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f7d63167783e..4e4218a96f16 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -690,7 +690,7 @@ def __init__(self, inp_channels, c_hidden): nn.GELU(), nn.Linear(c_hidden, inp_channels), ) - + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) def forward(self, x): From d20f8ff00a873d3199a6cba3941c86dd71941f5e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 15:12:30 +0200 Subject: [PATCH 130/181] formatting --- tests/pipelines/wuerstchen/test_wuerstchen_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 2af040058172..fad16819604a 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -20,8 +20,8 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline -from diffusers.pipelines.wuerstchen import DiffNeXt from diffusers.models import VQModelPaella +from diffusers.pipelines.wuerstchen import DiffNeXt from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps From 851705c983eabb8d3057016f7127e4ce0c2f3d06 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 15:24:31 +0200 Subject: [PATCH 131/181] fix autodoc link --- docs/source/en/api/pipelines/wuerstchen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 12dc2b6d5501..e391a23d92c1 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -107,7 +107,7 @@ The original codebase, as well as experimental ideas, can be found at [dome272/W ## WuerstchenPriorPipelineOutput -[[autodoc]] WuerstchenPriorPipelineOutput +[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput ## WuerstchenDecoderPipeline From cdf5109aa0304299c73ce69fc89942a5ed60c7d6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 18 Aug 2023 22:58:50 +0200 Subject: [PATCH 132/181] num_inference_steps is int --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 2 +- .../wuerstchen/pipeline_wuerstchen_combined.py | 10 +++++----- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 2 +- tests/pipelines/wuerstchen/test_wuerstchen_decoder.py | 2 +- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 00ca503fa0b2..1b834d9d3ff0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -233,7 +233,7 @@ def __call__( image_embeds: torch.Tensor, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: Dict[float, int] = {0.0: 12}, + num_inference_steps: Union[Dict[float, int], int] = 12, guidance_scale: float = 0.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index aa4666eefc8d..32b43bf4e23c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -147,13 +147,13 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 100, + num_inference_steps: int = 12, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, height: int = 512, width: int = 512, prior_guidance_scale: float = 4.0, - prior_num_inference_steps: int = 25, + prior_num_inference_steps: Union[int, Dict[float, int]] = {2 / 3: 20, 0.0: 10}, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -172,7 +172,7 @@ def __call__( if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - num_inference_steps (`int`, *optional*, defaults to 100): + num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. height (`int`, *optional*, defaults to 512): @@ -185,7 +185,7 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - prior_num_inference_steps (`int`, *optional*, defaults to 100): + prior_num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index ad0def8cba26..99193c4973e0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -209,7 +209,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: Dict[float, int] = {2 / 3: 20, 0.0: 10}, + num_inference_steps: Union[int, Dict[float, int]] = {2 / 3: 20, 0.0: 10}, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index fad16819604a..237e8ea76432 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -133,7 +133,7 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "horse", "generator": generator, "guidance_scale": 1.0, - "num_inference_steps": {0.0: 2}, + "num_inference_steps": 2, "output_type": "np", } return inputs diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 254222bdf43c..c8f2acbccd7c 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -124,7 +124,7 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "horse", "generator": generator, "guidance_scale": 4.0, - "num_inference_steps": {0.0: 2}, + "num_inference_steps": 2, "output_type": "np", } return inputs From 7a24a7d00b4bf3fe4e973fcc68620b01c64c4bf4 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 21 Aug 2023 12:31:06 -0400 Subject: [PATCH 133/181] remove comments --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index ad0def8cba26..06d530e56b47 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -235,9 +235,7 @@ def __call__( ) dtype = text_encoder_hidden_states.dtype - # latent_height = int(self.multiple * (height / self.multiple) / (1024 / 24)) latent_height = ceil(height / 42.67) - # latent_width = int(self.multiple * (width / self.multiple) / (1024 / 24)) latent_width = ceil(width / 42.67) num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) From 11cb295c55a785fc4c2dda34753788ec54257a29 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Tue, 22 Aug 2023 16:43:45 -0400 Subject: [PATCH 134/181] fix example in docs --- docs/source/en/api/pipelines/wuerstchen.md | 48 ++++++++++------------ 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index e391a23d92c1..02a1c8284038 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -17,77 +17,71 @@ After the initial paper release, we have improved numerous things in the archite ## Text-to-Image Generation -For the sake of explanation, since the model consists of different stages we will perform generation manually as: +For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenPipeline` and can be used as follows: ```python import torch -from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline +from diffusers import WuerstchenPipeline device = "cuda" dtype = torch.float16 num_images_per_prompt = 2 -prior_pipeline = WuerstchenPriorPipeline.from_pretrained( - "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype -).to(device) -generator_pipeline = WuerstchenDecoderPipeline.from_pretrained( - "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype +pipeline = WuerstchenPipeline.from_pretrained( + "warp-diffusion/WuerstchenPipeline", torch_dtype=dtype ).to(device) caption = "A captivating artwork of a mysterious stone golem" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" -prior_output = prior_pipeline( +output = pipeline( prompt=caption, height=1024, width=1024, negative_prompt=negative_prompt, guidance_scale=8.0, num_images_per_prompt=num_images_per_prompt, -) -generator_output = generator_pipeline( - predicted_image_embeddings=prior_output.image_embeds, - prompt=caption, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - guidance_scale=0.0, output_type="pil", ).images - ``` -## Pipeline Explained - -Würstchen consists out of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating images conditioned on text, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A both happen in the `generator_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). - -## Combined Pipeline - -For the sake of usability we have combined the two pipelines into one. This pipeline is called `WuerstchenPipeline` and can be used as follows: +For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating text-conditional images, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). ```python import torch -from diffusers import WuerstchenPipeline +from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline device = "cuda" dtype = torch.float16 num_images_per_prompt = 2 -pipeline = WuerstchenPipeline.from_pretrained( - "warp-diffusion/WuerstchenPipeline", torch_dtype=dtype +prior_pipeline = WuerstchenPriorPipeline.from_pretrained( + "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype +).to(device) +decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained( + "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype ).to(device) caption = "A captivating artwork of a mysterious stone golem" negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" -output = pipeline( +prior_output = prior_pipeline( prompt=caption, height=1024, width=1024, negative_prompt=negative_prompt, guidance_scale=8.0, num_images_per_prompt=num_images_per_prompt, +) +decoder_output = decoder_pipeline( + predicted_image_embeddings=prior_output.image_embeds, + prompt=caption, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + guidance_scale=0.0, output_type="pil", ).images + ``` The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). From 23e8740fdf90317d9e29b17ac2c548ccf6280210 Mon Sep 17 00:00:00 2001 From: Dominic Rampas <61938694+dome272@users.noreply.github.com> Date: Tue, 22 Aug 2023 20:52:41 -0400 Subject: [PATCH 135/181] Update src/diffusers/pipelines/wuerstchen/diffnext.py Co-authored-by: Patrick von Platen --- src/diffusers/pipelines/wuerstchen/diffnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/diffnext.py b/src/diffusers/pipelines/wuerstchen/diffnext.py index 3eeb6e6662d4..4d5a4f6b63ba 100644 --- a/src/diffusers/pipelines/wuerstchen/diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/diffnext.py @@ -59,7 +59,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): elif block_type == "T": return TimestepBlock(c_hidden, c_r) else: - raise Exception(f"Block type {block_type} not supported") + raise ValueError(f"Block type {block_type} not supported") # BLOCKS # -- down blocks From cc70ca52b6d9fccc0a9c2e4d2a3cfb444ab4297a Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Tue, 22 Aug 2023 22:19:04 -0400 Subject: [PATCH 136/181] rename layernorm to WuerstchenLayerNorm --- src/diffusers/pipelines/wuerstchen/common.py | 8 ++++---- src/diffusers/pipelines/wuerstchen/diffnext.py | 10 +++++----- src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/common.py index d2650110eb67..7c6568a3d193 100644 --- a/src/diffusers/pipelines/wuerstchen/common.py +++ b/src/diffusers/pipelines/wuerstchen/common.py @@ -4,7 +4,7 @@ from ...models.attention_processor import Attention -class LayerNorm2d(nn.LayerNorm): +class WuerstchenLayerNorm(nn.LayerNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -28,7 +28,7 @@ class ResBlockStageB(nn.Module): def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): super().__init__() self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( nn.Linear(c + c_skip, c * 4), nn.GELU(), @@ -50,7 +50,7 @@ class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) ) @@ -81,7 +81,7 @@ class AttnBlock(nn.Module): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) diff --git a/src/diffusers/pipelines/wuerstchen/diffnext.py b/src/diffusers/pipelines/wuerstchen/diffnext.py index 3eeb6e6662d4..a52f2209af15 100644 --- a/src/diffusers/pipelines/wuerstchen/diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/diffnext.py @@ -6,7 +6,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin -from .common import AttnBlock, LayerNorm2d, ResBlockStageB, TimestepBlock +from .common import AttnBlock, WuerstchenLayerNorm, ResBlockStageB, TimestepBlock class DiffNeXt(ModelMixin, ConfigMixin): @@ -48,7 +48,7 @@ def __init__( self.embedding = nn.Sequential( nn.PixelUnshuffle(patch_size), nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), ) def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): @@ -69,7 +69,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): if i > 0: down_block.append( nn.Sequential( - LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), ) ) @@ -91,7 +91,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): if i > 0: up_block.append( nn.Sequential( - LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6), nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), ) ) @@ -99,7 +99,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): # OUTPUT self.clf = nn.Sequential( - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1), nn.PixelShuffle(patch_size), ) diff --git a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py index 462851d08ff0..f289ebd09e8c 100644 --- a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py @@ -5,7 +5,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin -from .common import AttnBlock, LayerNorm2d, ResBlock, TimestepBlock +from .common import AttnBlock, WuerstchenLayerNorm, ResBlock, TimestepBlock class WuerstchenPrior(ModelMixin, ConfigMixin): @@ -26,7 +26,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro self.blocks.append(TimestepBlock(c, c_r)) self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) self.out = nn.Sequential( - LayerNorm2d(c, elementwise_affine=False, eps=1e-6), + WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), nn.Conv2d(c, c_in * 2, kernel_size=1), ) From 851115c0eff51e1bc73b473d949c2eb444f710ec Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 24 Aug 2023 17:53:10 +0200 Subject: [PATCH 137/181] rename DiffNext to WuerstchenDiffNeXt --- scripts/convert_wuerstchen.py | 4 +- .../pipelines/wuerstchen/__init__.py | 4 +- ...ommon.py => modeling_wuerstchen_common.py} | 22 ---------- ...ext.py => modeling_wuerstchen_diffnext.py} | 44 ++++++++++++++----- ..._prior.py => modeling_wuerstchen_prior.py} | 3 +- .../wuerstchen/pipeline_wuerstchen.py | 8 ++-- .../pipeline_wuerstchen_combined.py | 8 ++-- .../wuerstchen/pipeline_wuerstchen_prior.py | 2 +- .../wuerstchen/test_wuerstchen_decoder.py | 4 +- 9 files changed, 50 insertions(+), 49 deletions(-) rename src/diffusers/pipelines/wuerstchen/{common.py => modeling_wuerstchen_common.py} (77%) rename src/diffusers/pipelines/wuerstchen/{diffnext.py => modeling_wuerstchen_diffnext.py} (89%) rename src/diffusers/pipelines/wuerstchen/{wuerstchen_prior.py => modeling_wuerstchen_prior.py} (94%) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 3bceb51968d1..396c29f6b1df 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -12,7 +12,7 @@ WuerstchenPipeline, WuerstchenPriorPipeline, ) -from diffusers.pipelines.wuerstchen import DiffNeXt, WuerstchenPrior +from diffusers.pipelines.wuerstchen import WuerstchenDiffNeXt, WuerstchenPrior model_path = "models/" @@ -56,7 +56,7 @@ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights else: state_dict[key] = orig_state_dict[key] -generator = DiffNeXt() +generator = WuerstchenDiffNeXt() generator.load_state_dict(state_dict) # Prior diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 2403ca6db8f9..866d394826f8 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,8 +2,8 @@ if is_transformers_available() and is_torch_available(): - from .diffnext import DiffNeXt + from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + from .modeling_wuerstchen_prior import WuerstchenPrior from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline - from .wuerstchen_prior import WuerstchenPrior diff --git a/src/diffusers/pipelines/wuerstchen/common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py similarity index 77% rename from src/diffusers/pipelines/wuerstchen/common.py rename to src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index 7c6568a3d193..d7bde3f8e2d0 100644 --- a/src/diffusers/pipelines/wuerstchen/common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -24,28 +24,6 @@ def forward(self, x, t): return x * (1 + a) + b -class ResBlockStageB(nn.Module): - def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): - super().__init__() - self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c + c_skip, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - nn.Linear(c * 4, c), - ) - - def forward(self, x, x_skip=None): - x_res = x - x = self.norm(self.depthwise(x)) - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + x_res - - class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() diff --git a/src/diffusers/pipelines/wuerstchen/diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py similarity index 89% rename from src/diffusers/pipelines/wuerstchen/diffnext.py rename to src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py index 192e26f59678..8a4e4348caf6 100644 --- a/src/diffusers/pipelines/wuerstchen/diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -6,10 +6,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin -from .common import AttnBlock, WuerstchenLayerNorm, ResBlockStageB, TimestepBlock +from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm -class DiffNeXt(ModelMixin, ConfigMixin): +class WuerstchenDiffNeXt(ModelMixin, ConfigMixin): @register_to_config def __init__( self, @@ -105,7 +105,15 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): ) # --- WEIGHT INIT --- - self.apply(self._init_weights) # General init + self.apply(self._init_weights) + + def _init_weights(self, m): + # General init + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + for mapper in self.effnet_mappers: if mapper is not None: nn.init.normal_(mapper.weight, std=0.02) # conditionings @@ -117,16 +125,10 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): for level_block in self.down_blocks + self.up_blocks: for block in level_block: if isinstance(block, ResBlockStageB): - block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks)) + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks)) elif isinstance(block, TimestepBlock): nn.init.constant_(block.mapper.weight, 0) - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 @@ -214,3 +216,25 @@ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=Tr return (x_in - a) / b else: return a, b + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res diff --git a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py similarity index 94% rename from src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py rename to src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index f289ebd09e8c..b4ad701e1806 100644 --- a/src/diffusers/pipelines/wuerstchen/wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -5,7 +5,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin -from .common import AttnBlock, WuerstchenLayerNorm, ResBlock, TimestepBlock +from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm class WuerstchenPrior(ModelMixin, ConfigMixin): @@ -54,5 +54,4 @@ def forward(self, x, r, c): else: x = block(x) a, b = self.out(x).chunk(2, dim=1) - # denoised = a / (1-(1-b).pow(2)).sqrt() return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 1b834d9d3ff0..6e4da2635424 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -22,7 +22,7 @@ from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .diffnext import DiffNeXt +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -59,8 +59,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): The CLIP tokenizer. text_encoder (`CLIPTextModel`): The CLIP text encoder. - generator ([`DiffNeXt`]): - The DiffNeXt unet generator. + generator ([`WuerstchenDiffNeXt`]): + The WuerstchenDiffNeXt unet generator. vqgan ([`VQModelPaella`]): The VQGAN model. scheduler ([`DDPMWuerstchenScheduler`]): @@ -73,7 +73,7 @@ def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - generator: DiffNeXt, + generator: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, latent_dim_scale: float = 10.67, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 32b43bf4e23c..bfa551ced1cb 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -20,10 +20,10 @@ from ...schedulers import DDPMWuerstchenScheduler from ...utils import replace_example_docstring from ..pipeline_utils import DiffusionPipeline -from .diffnext import DiffNeXt +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt +from .modeling_wuerstchen_prior import WuerstchenPrior from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline -from .wuerstchen_prior import WuerstchenPrior TEXT2IMAGE_EXAMPLE_DOC_STRING = """ @@ -52,7 +52,7 @@ class WuerstchenPipeline(DiffusionPipeline): The decoder tokenizer to be used for text inputs. text_encoder (:class:`~transformers.CLIPTextModel`): The decoder text encoder to be used for text inputs. - generator (:class:`~diffusions.models.DiffNeXt`): + generator (:class:`~diffusions.models.WuerstchenDiffNeXt`): The generator model to be used for decoder image generation pipeline. scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): The scheduler to be used for decoder image generation pipeline. @@ -74,7 +74,7 @@ def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - generator: DiffNeXt, + generator: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: VQModelPaella, prior_tokenizer: CLIPTokenizer, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 61fd8f3df8b5..7b9abc2e251e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -23,7 +23,7 @@ from ...schedulers import DDPMWuerstchenScheduler from ...utils import BaseOutput, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from .wuerstchen_prior import WuerstchenPrior +from .modeling_wuerstchen_prior import WuerstchenPrior logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 237e8ea76432..0c01ae628dac 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -21,7 +21,7 @@ from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline from diffusers.models import VQModelPaella -from diffusers.pipelines.wuerstchen import DiffNeXt +from diffusers.pipelines.wuerstchen import WuerstchenDiffNeXt from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps @@ -102,7 +102,7 @@ def dummy_generator(self): "c_in": 4, } - model = DiffNeXt(**model_kwargs) + model = WuerstchenDiffNeXt(**model_kwargs) return model.eval() def get_dummy_components(self): From ecd6ab35baad4f3d5137524ade064c7e89aa8886 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 25 Aug 2023 10:48:14 +0200 Subject: [PATCH 138/181] added comment about MixingResidualBlock --- src/diffusers/models/resnet.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 4e4218a96f16..5db7b2a269e2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -675,7 +675,11 @@ def forward(self, inputs): class MixingResidualBlock(nn.Module): - def __init__(self, inp_channels, c_hidden): + """ + Residual block with mixing used by Paella's VQ-VAE. + """ + + def __init__(self, inp_channels, embed_dim): super().__init__() # depthwise self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) @@ -686,9 +690,7 @@ def __init__(self, inp_channels, c_hidden): # channelwise self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( - nn.Linear(inp_channels, c_hidden), - nn.GELU(), - nn.Linear(c_hidden, inp_channels), + nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) ) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) From 44ad4c38450d91f8f4fbc0fec5986ac7a2c83487 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 28 Aug 2023 11:46:34 +0200 Subject: [PATCH 139/181] move paella vq-vae to pipelines' folder --- scripts/convert_wuerstchen.py | 5 +- src/diffusers/__init__.py | 1 - src/diffusers/models/__init__.py | 1 - src/diffusers/models/resnet.py | 30 ----------- .../pipelines/wuerstchen/__init__.py | 1 + .../wuerstchen/modeling_paella_vq_model.py} | 53 +++++++++++++++---- .../wuerstchen/modeling_wuerstchen_common.py | 15 ++++++ .../modeling_wuerstchen_diffnext.py | 15 ++++++ .../wuerstchen/modeling_wuerstchen_prior.py | 15 ++++++ .../wuerstchen/pipeline_wuerstchen.py | 8 +-- .../pipeline_wuerstchen_combined.py | 6 +-- src/diffusers/utils/dummy_pt_objects.py | 15 ------ 12 files changed, 99 insertions(+), 66 deletions(-) rename src/diffusers/{models/vq_paella.py => pipelines/wuerstchen/modeling_paella_vq_model.py} (75%) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 396c29f6b1df..fc23fc4f2250 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -7,12 +7,11 @@ from diffusers import ( DDPMWuerstchenScheduler, - VQModelPaella, WuerstchenDecoderPipeline, WuerstchenPipeline, WuerstchenPriorPipeline, ) -from diffusers.pipelines.wuerstchen import WuerstchenDiffNeXt, WuerstchenPrior +from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior model_path = "models/" @@ -24,7 +23,7 @@ state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] state_dict.pop("vquantizer.codebook.weight") -vqmodel = VQModelPaella(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent) +vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent) vqmodel.load_state_dict(state_dict) # Clip Text encoder and tokenizer diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7f3a465af9e3..641539919c95 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -51,7 +51,6 @@ UNet2DModel, UNet3DConditionModel, VQModel, - VQModelPaella, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b8653fe8b97b..54e77df0ff72 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,7 +31,6 @@ from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel from .vq_model import VQModel - from .vq_paella import VQModelPaella if is_flax_available(): from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 5db7b2a269e2..72aa17ed2c2d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -674,36 +674,6 @@ def forward(self, inputs): return output -class MixingResidualBlock(nn.Module): - """ - Residual block with mixing used by Paella's VQ-VAE. - """ - - def __init__(self, inp_channels, embed_dim): - super().__init__() - # depthwise - self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) - self.depthwise = nn.Sequential( - nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) - ) - - # channelwise - self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) - ) - - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - def forward(self, x): - mods = self.gammas - x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] - x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - return x - - # unet_rl.py class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 866d394826f8..998d48689994 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -2,6 +2,7 @@ if is_transformers_available() and is_torch_available(): + from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt from .modeling_wuerstchen_prior import WuerstchenPrior from .pipeline_wuerstchen import WuerstchenDecoderPipeline diff --git a/src/diffusers/models/vq_paella.py b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py similarity index 75% rename from src/diffusers/models/vq_paella.py rename to src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py index 4f35928e17fe..e33d2e576c01 100644 --- a/src/diffusers/models/vq_paella.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py @@ -1,3 +1,4 @@ +# Copyright (c) 2022 Dominic Rampas MIT License # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,19 +12,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Union import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from .modeling_utils import ModelMixin -from .resnet import MixingResidualBlock -from .vae import DecoderOutput, VectorQuantizer -from .vq_model import VQEncoderOutput +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...models.vae import DecoderOutput, VectorQuantizer +from ...models.vq_model import VQEncoderOutput + + +class MixingResidualBlock(nn.Module): + """ + Residual block with mixing used by Paella's VQ-VAE. + """ + + def __init__(self, inp_channels, embed_dim): + super().__init__() + # depthwise + self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) + ) + + # channelwise + self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + def forward(self, x): + mods = self.gammas + x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x -class VQModelPaella(ModelMixin, ConfigMixin): + +class PaellaVQModel(ModelMixin, ConfigMixin): r"""VQ-VAE model from Paella model. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library @@ -57,11 +88,11 @@ def __init__( super().__init__() c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] + # Encoder blocks self.in_block = nn.Sequential( nn.PixelUnshuffle(up_down_scale_factor), nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), ) - down_blocks = [] for i in range(levels): if i > 0: @@ -75,6 +106,8 @@ def __init__( ) ) self.down_blocks = nn.Sequential(*down_blocks) + + # Vector Quantizer self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) # Decoder blocks @@ -97,7 +130,7 @@ def __init__( def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: h = self.in_block(x) - h = self.down_blocks(h) / self.config.scale_factor + h = self.down_blocks(h) if not return_dict: return (h,) @@ -108,9 +141,9 @@ def decode( self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: if not force_not_quantize: - quant, _, _ = self.vquantizer(h * self.config.scale_factor) + quant, _, _ = self.vquantizer(h) else: - quant = h * self.config.scale_factor + quant = h x = self.up_blocks(quant) dec = self.out_block(x) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index d7bde3f8e2d0..b3aac39386bc 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -1,3 +1,18 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py index 8a4e4348caf6..fe862a254834 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -1,3 +1,18 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import numpy as np diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index b4ad701e1806..9bd29b59b3af 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -1,3 +1,18 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import torch diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 6e4da2635424..81003c4f829a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -18,10 +18,10 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt @@ -61,7 +61,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): The CLIP text encoder. generator ([`WuerstchenDiffNeXt`]): The WuerstchenDiffNeXt unet generator. - vqgan ([`VQModelPaella`]): + vqgan ([`PaellaVQModel`]): The VQGAN model. scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. @@ -75,7 +75,7 @@ def __init__( text_encoder: CLIPTextModel, generator: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, - vqgan: VQModelPaella, + vqgan: PaellaVQModel, latent_dim_scale: float = 10.67, ) -> None: super().__init__() @@ -294,6 +294,8 @@ def __call__( generator=generator, ).prev_sample + # scale and decode the image latents with vq-vae + latents = self.vqgan.config.scaling_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type not in ["pt", "np", "pil"]: diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index bfa551ced1cb..e187464208eb 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -16,10 +16,10 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...models import VQModelPaella from ...schedulers import DDPMWuerstchenScheduler from ...utils import replace_example_docstring from ..pipeline_utils import DiffusionPipeline +from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt from .modeling_wuerstchen_prior import WuerstchenPrior from .pipeline_wuerstchen import WuerstchenDecoderPipeline @@ -56,7 +56,7 @@ class WuerstchenPipeline(DiffusionPipeline): The generator model to be used for decoder image generation pipeline. scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): The scheduler to be used for decoder image generation pipeline. - vqgan (:class:`~diffusions.models.VQModelPaella`): + vqgan (:class:`~diffusions.pipelines.wuerstchen.modeling_paella_vq_model.PaellaVQModel`): The VQGAN model to be used for decoder image generation pipeline. prior_tokenizer (:class:`~transformers.CLIPTokenizer`): The prior tokenizer to be used for text inputs. @@ -76,7 +76,7 @@ def __init__( text_encoder: CLIPTextModel, generator: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, - vqgan: VQModelPaella, + vqgan: PaellaVQModel, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, prior_prior: WuerstchenPrior, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index e16499c73a83..8426db53eb42 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,21 +227,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQModelPaella(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From d3e591973a518172a9e554c9403bf80b7938a4a3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 28 Aug 2023 12:47:06 +0200 Subject: [PATCH 140/181] initial decoder test --- .../schedulers/scheduling_ddpm_wuerstchen.py | 1 + .../wuerstchen/test_wuerstchen_decoder.py | 20 ++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 40140942ea32..4bd4d51b0c66 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -1,3 +1,4 @@ +# Copyright (c) 2022 Pablo Pernías MIT License # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 0c01ae628dac..94ed2e88ac5f 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -20,8 +20,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline -from diffusers.models import VQModelPaella -from diffusers.pipelines.wuerstchen import WuerstchenDiffNeXt +from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps @@ -90,8 +89,11 @@ def dummy_vqgan(self): model_kwargs = { "in_channels": 3, + "embed_dim": 2, + "bottleneck_blocks": 1, + "num_vq_embeddings": 2, } - model = VQModelPaella(**model_kwargs) + model = PaellaVQModel(**model_kwargs) return model.eval() @property @@ -99,7 +101,15 @@ def dummy_generator(self): torch.manual_seed(0) model_kwargs = { - "c_in": 4, + "c_in": 1, + "c_cond": 1, + "c_r": 1, + "c_hidden": [2], + "effnet_embd": 1, + "nhead": [1], + "blocks": [1], + "level_config": ["CT"], + "clip_embd": self.text_embedder_hidden_size, } model = WuerstchenDiffNeXt(**model_kwargs) @@ -129,7 +139,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "image_embeds": torch.ones(1, 3, 32, 32, device=device), + "image_embeds": torch.ones((1, 16, 21, 21), device=device), "prompt": "horse", "generator": generator, "guidance_scale": 1.0, From 2f325e6920160fbc7bc2edd34e99871a02cb0871 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 28 Aug 2023 15:47:51 +0200 Subject: [PATCH 141/181] increased test_float16_inference expected diff --- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index c8f2acbccd7c..42f9c24f996b 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -187,3 +187,7 @@ def test_attention_slicing_forward_pass(self): test_max_difference=test_max_difference, test_mean_pixel_difference=test_mean_pixel_difference, ) + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_float16_inference(self): + super().test_float16_inference(expected_max_diff=1.1) From db8fae2789a720433c37700812fbed008c45e658 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 28 Aug 2023 17:42:51 +0200 Subject: [PATCH 142/181] self_attn is always true --- .../pipelines/wuerstchen/modeling_wuerstchen_diffnext.py | 3 +-- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py index fe862a254834..d22eb7b7c991 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -42,7 +42,6 @@ def __init__( clip_embd=1024, kernel_size=3, dropout=0.1, - self_attn=True, ): super().__init__() self.c_r = c_r @@ -70,7 +69,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): if block_type == "C": return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "A": - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout) elif block_type == "T": return TimestepBlock(c_hidden, c_r) else: diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 42f9c24f996b..656b0a923485 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -188,6 +188,6 @@ def test_attention_slicing_forward_pass(self): test_mean_pixel_difference=test_mean_pixel_difference, ) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skip(reason="flakey and float16 requires CUDA") def test_float16_inference(self): - super().test_float16_inference(expected_max_diff=1.1) + super().test_float16_inference() From 754b9ab8c6e2202341f2bd56b6d82f4b421d1cf5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 30 Aug 2023 19:02:26 +0200 Subject: [PATCH 143/181] more passing decoder tests --- .../wuerstchen/pipeline_wuerstchen.py | 2 +- .../wuerstchen/test_wuerstchen_decoder.py | 54 ++++++++----------- 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 81003c4f829a..8cf2b6abd4e8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -295,7 +295,7 @@ def __call__( ).prev_sample # scale and decode the image latents with vq-vae - latents = self.vqgan.config.scaling_factor * latents + latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type not in ["pt", "np", "pil"]: diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 94ed2e88ac5f..f572e852a8f3 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -73,6 +73,7 @@ def dummy_text_encoder(self): config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, + projection_dim=self.text_embedder_hidden_size, hidden_size=self.text_embedder_hidden_size, intermediate_size=37, layer_norm_eps=1e-05, @@ -88,8 +89,6 @@ def dummy_vqgan(self): torch.manual_seed(0) model_kwargs = { - "in_channels": 3, - "embed_dim": 2, "bottleneck_blocks": 1, "num_vq_embeddings": 2, } @@ -101,15 +100,13 @@ def dummy_generator(self): torch.manual_seed(0) model_kwargs = { - "c_in": 1, - "c_cond": 1, - "c_r": 1, - "c_hidden": [2], - "effnet_embd": 1, - "nhead": [1], - "blocks": [1], + "c_cond": self.text_embedder_hidden_size, + "c_hidden": [320], + "nhead": [-1], + "blocks": [4], "level_config": ["CT"], "clip_embd": self.text_embedder_hidden_size, + "inject_effnet": [False], } model = WuerstchenDiffNeXt(**model_kwargs) @@ -139,7 +136,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "image_embeds": torch.ones((1, 16, 21, 21), device=device), + "image_embeds": torch.ones((1, 16, 10, 10), device=device), "prompt": "horse", "generator": generator, "guidance_scale": 1.0, @@ -159,29 +156,16 @@ def test_wuerstchen_decoder(self): pipe.set_progress_bar_config(disable=None) output = pipe(**self.get_dummy_inputs(device)) - image = output.image_embeds - - image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] - - image_slice = image[0, 0, 0, -10:] - image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] - - assert image.shape == (1, 2, 24, 24) - - expected_slice = np.array( - [ - -7172.9814, - -3438.9731, - -1093.4564, - 388.91516, - -7471.7383, - -7998.2944, - -5328.388, - 218.0543, - -2731.6716, - -8056.8545, - ], - ) + image = output.images + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False) + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 424, 424, 3) + + expected_slice = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -206,3 +190,7 @@ def test_attention_slicing_forward_pass(self): test_max_difference=test_max_difference, test_mean_pixel_difference=test_mean_pixel_difference, ) + + @unittest.skip(reason="bf16 not supported and requires CUDA") + def test_float16_inference(self): + super().test_float16_inference() From f162d778390cbe77083646992ea028bfe53def18 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 11:39:50 +0200 Subject: [PATCH 144/181] batch image_embeds --- tests/pipelines/wuerstchen/test_wuerstchen_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index f572e852a8f3..bfa8f5b7d7f0 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -33,7 +33,7 @@ class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = WuerstchenDecoderPipeline params = ["prompt"] - batch_params = ["prompt", "negative_prompt"] + batch_params = ["image_embeds", "prompt", "negative_prompt"] required_optional_params = [ "num_images_per_prompt", "generator", From c74f9c6fc240d106882203402f9a577c3dabda84 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 11:57:54 +0200 Subject: [PATCH 145/181] fix failing tests --- .../wuerstchen/pipeline_wuerstchen.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 8cf2b6abd4e8..89bd51ae7d55 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -115,14 +115,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - ): + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( @@ -215,22 +208,25 @@ def check_inputs( ) if do_classifier_free_guidance: - assert ( - predicted_image_embeddings.size(0) == text_encoder_hidden_states.size(0) // 2 - ), f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." - else: - if predicted_image_embeddings.size(0) * 2 == text_encoder_hidden_states.size(0): - text_encoder_hidden_states = text_encoder_hidden_states.chunk(2)[0] - assert predicted_image_embeddings.size(0) == text_encoder_hidden_states.size( - 0 - ), f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + if predicted_image_embeddings.size(0) != text_encoder_hidden_states.size(0) // 2: + raise ValueError( + f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + ) + + # if predicted_image_embeddings.size(0) * 2 == text_encoder_hidden_states.size(0): + text_encoder_hidden_states = text_encoder_hidden_states.chunk(2)[0] + + if predicted_image_embeddings.size(0) != text_encoder_hidden_states.size(0): + raise ValueError( + f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." + ) return predicted_image_embeddings, text_encoder_hidden_states @torch.no_grad() def __call__( self, - image_embeds: torch.Tensor, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: Union[Dict[float, int], int] = 12, @@ -244,6 +240,10 @@ def __call__( device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if isinstance(num_inference_steps, int): num_inference_steps = {0.0: num_inference_steps} From 741f6ef2d939c6a41249fd9357eece00040c3eaa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 12:20:35 +0200 Subject: [PATCH 146/181] set the correct dtype --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 89bd51ae7d55..b010c7ebfbdc 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -259,7 +259,7 @@ def __call__( image_embeds, text_encoder_hidden_states, do_classifier_free_guidance, device ) - dtype = predicted_image_embeddings.dtype + dtype = self.generator.dtype latent_height = int(predicted_image_embeddings.size(2) * self.config.latent_dim_scale) latent_width = int(predicted_image_embeddings.size(3) * self.config.latent_dim_scale) latent_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) From 09781d2ac7f63bf70c4aaed6474827a655b3417d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 13:30:39 +0200 Subject: [PATCH 147/181] relax inference test --- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 656b0a923485..e42a01438f13 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -176,6 +176,7 @@ def test_inference_batch_single_identical(self): test_max_difference=test_max_difference, relax_max_difference=relax_max_difference, test_mean_pixel_difference=test_mean_pixel_difference, + expected_max_diff=1e-1, ) @skip_mps From 06cd467dc1594898c173519a8764c6381343cfe5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 17:14:50 +0200 Subject: [PATCH 148/181] update prior --- .../wuerstchen/test_wuerstchen_prior.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index e42a01438f13..ac70cc0bea98 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -151,16 +151,16 @@ def test_wuerstchen_prior(self): expected_slice = np.array( [ - -7172.9814, - -3438.9731, - -1093.4564, - 388.91516, - -7471.7383, - -7998.2944, - -5328.388, - 218.0543, - -2731.6716, - -8056.8545, + -7172.837, + -3438.855, + -1093.312, + 388.8835, + -7471.467, + -7998.1206, + -5328.259, + 218.00089, + -2731.5745, + -8056.734, ], ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 7c1547185a6be3c47fa6ccfcc40ca5b56dbf9136 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 18:45:49 +0200 Subject: [PATCH 149/181] added combined pipeline test --- .../wuerstchen/test_wuerstchen_combined.py | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 tests/pipelines/wuerstchen/test_wuerstchen_combined.py diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py new file mode 100644 index 000000000000..5b62fc0d7d7d --- /dev/null +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import DDPMWuerstchenScheduler, WuerstchenPipeline +from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WuerstchenPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WuerstchenPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "guidance_scale", + "num_images_per_prompt", + "output_type", + "return_dict", + ] + test_xformers_attention = True + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def dummy_prior_prior(self): + torch.manual_seed(0) + + model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2} + model = WuerstchenPrior(**model_kwargs) + return model.eval() + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_prior_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config).eval() + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + projection_dim=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config).eval() + + @property + def dummy_vqgan(self): + torch.manual_seed(0) + + model_kwargs = { + "bottleneck_blocks": 1, + "num_vq_embeddings": 2, + } + model = PaellaVQModel(**model_kwargs) + return model.eval() + + @property + def dummy_generator(self): + torch.manual_seed(0) + + model_kwargs = { + "c_cond": self.text_embedder_hidden_size, + "c_hidden": [320], + "nhead": [-1], + "blocks": [4], + "level_config": ["CT"], + "clip_embd": self.text_embedder_hidden_size, + "inject_effnet": [False], + } + + model = WuerstchenDiffNeXt(**model_kwargs) + return model.eval() + + def get_dummy_components(self): + prior_prior = self.dummy_prior_prior + prior_text_encoder = self.dummy_prior_text_encoder + + scheduler = DDPMWuerstchenScheduler() + tokenizer = self.dummy_tokenizer + + text_encoder = self.dummy_text_encoder + generator = self.dummy_generator + vqgan = self.dummy_vqgan + + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "generator": generator, + "vqgan": vqgan, + "scheduler": scheduler, + "prior_prior": prior_prior, + "prior_text_encoder": prior_text_encoder, + "prior_tokenizer": tokenizer, + "prior_scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "guidance_scale": 4.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_wuerstchen(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[-3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + + expected_slice = np.array([1.0, 0.6543794, 0.20357049, 0.234462, 1.0, 0.0, 0.62647814, 1.0, 0.0]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert ( + np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + + @require_torch_gpu + def test_offloads(self): + pipes = [] + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_sequential_cpu_offload() + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs).images + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=1e-2) + + @unittest.skip(reason="flakey and float16 requires CUDA") + def test_float16_inference(self): + super().test_float16_inference() From 1d6615e1c606f23346d280408a80d9170202a320 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Aug 2023 20:54:30 +0200 Subject: [PATCH 150/181] faster test --- tests/pipelines/wuerstchen/test_wuerstchen_combined.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 5b62fc0d7d7d..abf1589b3f4c 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -42,6 +42,7 @@ class WuerstchenPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "guidance_scale", "negative_prompt", "num_inference_steps", + "prior_num_inference_steps", "return_dict", "guidance_scale", "num_images_per_prompt", @@ -163,6 +164,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "guidance_scale": 4.0, "num_inference_steps": 2, + "prior_num_inference_steps": 2, "output_type": "np", } return inputs @@ -187,7 +189,7 @@ def test_wuerstchen(self): assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([1.0, 0.6543794, 0.20357049, 0.234462, 1.0, 0.0, 0.62647814, 1.0, 0.0]) + expected_slice = np.array([1.0, 0.35500407, 0.0, 0.0, 0.03041486, 0.0, 0.0, 0.0, 0.15140978]) assert ( np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 8fdce9dab54b74c32078161e0b589e8eed0c0850 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 1 Sep 2023 14:18:49 +0200 Subject: [PATCH 151/181] faster test --- .../wuerstchen/pipeline_wuerstchen.py | 33 +++++++++++++++++-- .../wuerstchen/pipeline_wuerstchen_prior.py | 31 ++++++++++++++++- .../wuerstchen/test_wuerstchen_combined.py | 12 +++++-- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index b010c7ebfbdc..372e466cb424 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import is_accelerate_available, logging, randn_tensor +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt @@ -111,10 +111,39 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder, self.vqgan, self.generator]: + for cpu_offloaded_model in [self.text_encoder, self.generator, self.vqgan]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.generator]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.prior_hook = hook + + _, hook = cpu_offload_with_hook(self.vqgan, device, prev_module_hook=self.prior_hook) + + self.final_offload_hook = hook + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 7b9abc2e251e..8c27e21e0ce5 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, is_accelerate_available, logging, randn_tensor +from ...utils import BaseOutput, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from .modeling_wuerstchen_prior import WuerstchenPrior @@ -110,6 +110,35 @@ def enable_sequential_cpu_offload(self, gpu_id=0): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.prior_hook = hook + + _, hook = cpu_offload_with_hook(self.prior, device, prev_module_hook=self.prior_hook) + + self.final_offload_hook = hook + def prepare_latents(self, shape, dtype, device, generator, latents): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index abf1589b3f4c..59b0d9c2773e 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -166,6 +166,8 @@ def get_dummy_inputs(self, device, seed=0): "num_inference_steps": 2, "prior_num_inference_steps": 2, "output_type": "np", + "height": 128, + "width": 128, } return inputs @@ -187,9 +189,9 @@ def test_wuerstchen(self): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[-3:, -3:, -1] - assert image.shape == (1, 512, 512, 3) + assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([1.0, 0.35500407, 0.0, 0.0, 0.03041486, 0.0, 0.0, 0.0, 0.15140978]) + expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898]) assert ( np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -210,6 +212,11 @@ def test_offloads(self): sd_pipe.enable_sequential_cpu_offload() pipes.append(sd_pipe) + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_model_cpu_offload() + pipes.append(sd_pipe) + image_slices = [] for pipe in pipes: inputs = self.get_dummy_inputs(torch_device) @@ -218,6 +225,7 @@ def test_offloads(self): image_slices.append(image[0, -3:, -3:, -1].flatten()) assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=1e-2) From f9a925973a70fd3e7b9048931e4fb37194fccb16 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 1 Sep 2023 19:05:11 +0200 Subject: [PATCH 152/181] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen --- .../wuerstchen/pipeline_wuerstchen_combined.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index e187464208eb..204e0a999df7 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -48,23 +48,23 @@ class WuerstchenPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - tokenizer (:class:`~transformers.CLIPTokenizer`): + tokenizer (`CLIPTokenizer`): The decoder tokenizer to be used for text inputs. - text_encoder (:class:`~transformers.CLIPTextModel`): + text_encoder (`CLIPTextModel`): The decoder text encoder to be used for text inputs. - generator (:class:`~diffusions.models.WuerstchenDiffNeXt`): + generator (`WuerstchenDiffNeXt`): The generator model to be used for decoder image generation pipeline. - scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): + scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for decoder image generation pipeline. - vqgan (:class:`~diffusions.pipelines.wuerstchen.modeling_paella_vq_model.PaellaVQModel`): + vqgan (`PaellaVQModel`): The VQGAN model to be used for decoder image generation pipeline. - prior_tokenizer (:class:`~transformers.CLIPTokenizer`): + prior_tokenizer (`CLIPTokenizer`): The prior tokenizer to be used for text inputs. - prior_text_encoder (:class:`~transformers.CLIPTextModel`): + prior_text_encoder (`CLIPTextModel`): The prior text encoder to be used for text inputs. - prior_prior (:class:`~diffusions.pipelines.wuerstchen.wuerstchen_prior.WuerstchenPrior`): + prior_prior (`WuerstchenPrior`): The prior model to be used for prior pipeline. - prior_scheduler (:class:`~diffusions.schedulers.DDPMWuerstchenScheduler`): + prior_scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for prior pipeline. """ From d45550b52eaf53766b42dc542dc0ccb45f6f5df5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 4 Sep 2023 16:25:37 +0200 Subject: [PATCH 153/181] fix issues from review --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 204e0a999df7..220fc4176ac5 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -135,7 +135,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0): def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) - self.decoder_pipe.enable_model_cpu_offload() def set_progress_bar_config(self, **kwargs): self.prior_pipe.set_progress_bar_config(**kwargs) @@ -185,9 +184,11 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - prior_num_inference_steps (`int`, *optional*, defaults to 30): + prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. + expense of slower inference. This pipeline takes an optional dictionary of the form for example + `{2 / 3: 20, 0.0: 10}` 20 steps for the first 1/3 of denoising and 10 steps for the last 2/3 of the + denoising process. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -217,6 +218,8 @@ def __call__( Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` + [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. """ prior_outputs = self.prior_pipe( prompt=prompt, From 6f978ed55a5552c8e3e780b373caeed8758dbdb4 Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 4 Sep 2023 13:58:00 -0400 Subject: [PATCH 154/181] update wuerstchen.md + change generator name --- docs/source/en/api/pipelines/wuerstchen.md | 27 ++++++++++++--- .../wuerstchen/pipeline_wuerstchen.py | 21 ++++++------ .../pipeline_wuerstchen_combined.py | 34 +++++++++---------- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 02a1c8284038..7fd986205a08 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -1,5 +1,7 @@ # Würstchen + + [Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville. The abstract from the paper is: @@ -15,6 +17,16 @@ After the initial paper release, we have improved numerous things in the archite - Multi Aspect Resolution Sampling - Better quality +We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are: +- v2-base +- v2-aesthetic +- v2-interpolated (50% interpolation between v2-base and v2-aesthetic) + +We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations. +A comparison can be seen here: + + + ## Text-to-Image Generation For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenPipeline` and can be used as follows: @@ -32,14 +44,15 @@ pipeline = WuerstchenPipeline.from_pretrained( ).to(device) caption = "A captivating artwork of a mysterious stone golem" -negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" +negative_prompt = "" output = pipeline( prompt=caption, height=1024, width=1024, negative_prompt=negative_prompt, - guidance_scale=8.0, + prior_guidance_scale=4.0, + decoder_guidance_scale=0.0, num_images_per_prompt=num_images_per_prompt, output_type="pil", ).images @@ -63,14 +76,14 @@ decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained( ).to(device) caption = "A captivating artwork of a mysterious stone golem" -negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated" +negative_prompt = "" prior_output = prior_pipeline( prompt=caption, height=1024, width=1024, negative_prompt=negative_prompt, - guidance_scale=8.0, + guidance_scale=4.0, num_images_per_prompt=num_images_per_prompt, ) decoder_output = decoder_pipeline( @@ -81,7 +94,13 @@ decoder_output = decoder_pipeline( guidance_scale=0.0, output_type="pil", ).images +``` +## Speed-Up Inference +You can make use of ``torch.compile`` function and gain a speed-up of about 2-3x: +```py +pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True) +pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True) ``` The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 372e466cb424..d92d40d5ad75 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -59,8 +59,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): The CLIP tokenizer. text_encoder (`CLIPTextModel`): The CLIP text encoder. - generator ([`WuerstchenDiffNeXt`]): - The WuerstchenDiffNeXt unet generator. + decoder ([`WuerstchenDiffNeXt`]): + The WuerstchenDiffNeXt unet decoder. vqgan ([`PaellaVQModel`]): The VQGAN model. scheduler ([`DDPMWuerstchenScheduler`]): @@ -73,7 +73,7 @@ def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - generator: WuerstchenDiffNeXt, + decoder: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, latent_dim_scale: float = 10.67, @@ -82,7 +82,7 @@ def __init__( self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, - generator=generator, + decoder=decoder, scheduler=scheduler, vqgan=vqgan, ) @@ -101,7 +101,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents): def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, - generator, and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and + decoder, and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): @@ -111,7 +111,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.text_encoder, self.generator, self.vqgan]: + for cpu_offloaded_model in [self.text_encoder, self.decoder, self.vqgan]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @@ -134,7 +134,7 @@ def enable_model_cpu_offload(self, gpu_id=0): torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None - for cpu_offloaded_model in [self.text_encoder, self.generator]: + for cpu_offloaded_model in [self.text_encoder, self.decoder]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. @@ -271,7 +271,7 @@ def __call__( if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) if isinstance(num_inference_steps, int): num_inference_steps = {0.0: num_inference_steps} @@ -288,7 +288,7 @@ def __call__( image_embeds, text_encoder_hidden_states, do_classifier_free_guidance, device ) - dtype = self.generator.dtype + dtype = self.decoder.dtype latent_height = int(predicted_image_embeddings.size(2) * self.config.latent_dim_scale) latent_width = int(predicted_image_embeddings.size(3) * self.config.latent_dim_scale) latent_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) @@ -297,7 +297,6 @@ def __call__( timesteps = self.scheduler.timesteps latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents) - for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) effnet = ( @@ -305,7 +304,7 @@ def __call__( if do_classifier_free_guidance else predicted_image_embeddings ) - predicted_latents = self.generator( + predicted_latents = self.decoder( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, effnet=effnet, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 204e0a999df7..c5330f967dae 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -52,8 +52,8 @@ class WuerstchenPipeline(DiffusionPipeline): The decoder tokenizer to be used for text inputs. text_encoder (`CLIPTextModel`): The decoder text encoder to be used for text inputs. - generator (`WuerstchenDiffNeXt`): - The generator model to be used for decoder image generation pipeline. + decoder (`WuerstchenDiffNeXt`): + The decoder model to be used for decoder image generation pipeline. scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for decoder image generation pipeline. vqgan (`PaellaVQModel`): @@ -62,7 +62,7 @@ class WuerstchenPipeline(DiffusionPipeline): The prior tokenizer to be used for text inputs. prior_text_encoder (`CLIPTextModel`): The prior text encoder to be used for text inputs. - prior_prior (`WuerstchenPrior`): + prior (`WuerstchenPrior`): The prior model to be used for prior pipeline. prior_scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for prior pipeline. @@ -74,12 +74,12 @@ def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - generator: WuerstchenDiffNeXt, + decoder: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, - prior_prior: WuerstchenPrior, + prior: WuerstchenPrior, prior_scheduler: DDPMWuerstchenScheduler, ): super().__init__() @@ -87,16 +87,16 @@ def __init__( self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, - generator=generator, + decoder=decoder, scheduler=scheduler, vqgan=vqgan, - prior_prior=prior_prior, + prior=prior, prior_text_encoder=prior_text_encoder, prior_tokenizer=prior_tokenizer, prior_scheduler=prior_scheduler, ) self.prior_pipe = WuerstchenPriorPipeline( - prior=prior_prior, + prior=prior, text_encoder=prior_text_encoder, tokenizer=prior_tokenizer, scheduler=prior_scheduler, @@ -104,7 +104,7 @@ def __init__( self.decoder_pipe = WuerstchenDecoderPipeline( text_encoder=text_encoder, tokenizer=tokenizer, - generator=generator, + decoder=decoder, scheduler=scheduler, vqgan=vqgan, ) @@ -148,7 +148,7 @@ def __call__( prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: int = 12, - guidance_scale: float = 4.0, + decoder_guidance_scale: float = 4.0, num_images_per_prompt: int = 1, height: int = 512, width: int = 512, @@ -169,7 +169,7 @@ def __call__( The prompt or prompts to guide the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + if `decoder_guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 12): @@ -181,17 +181,17 @@ def __call__( The width in pixels of the generated image. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. prior_num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 4.0): + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -243,7 +243,7 @@ def __call__( image_embeds=image_embeds, num_inference_steps=num_inference_steps, generator=generator, - guidance_scale=guidance_scale, + guidance_scale=decoder_guidance_scale, output_type=output_type, return_dict=return_dict, ) From 81f64de171ad1a1f1731c8094ed6b661e9ea93cf Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Mon, 4 Sep 2023 21:38:56 -0400 Subject: [PATCH 155/181] resolve issues --- docs/source/en/api/pipelines/wuerstchen.md | 13 +- scripts/convert_wuerstchen.py | 2 +- .../wuerstchen/pipeline_wuerstchen.py | 139 +++++++++--------- .../pipeline_wuerstchen_combined.py | 20 +-- .../wuerstchen/pipeline_wuerstchen_prior.py | 87 +++++++---- 5 files changed, 151 insertions(+), 110 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 7fd986205a08..8a272a77fe8a 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -98,11 +98,22 @@ decoder_output = decoder_pipeline( ## Speed-Up Inference You can make use of ``torch.compile`` function and gain a speed-up of about 2-3x: -```py + +```python pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True) pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True) ``` +## Limitations +- Due to the high compression employed by Würstchen, generations can lack a good amount +of detail. To our human eye, this is especially noticeable in faces, hands etc. +- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution +after 1024x1024 is 1152x1152 +- The model lacks the ability to render correct text in images +- The model often does not achieve photorealism +- Difficult compositional prompts are hard for the model + + The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). ## WuerschenPipeline diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index fc23fc4f2250..6baa1864e57f 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -1,4 +1,4 @@ -# Run inside root directory of official source code +# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/ import os import torch diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index d92d40d5ad75..622ed840dc18 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt @@ -42,7 +42,7 @@ >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) - >>> images = gen_pipe(prior_output.image_embeds, prompt=prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) ``` """ @@ -66,7 +66,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. latent_dim_scale (float, `optional`, defaults to 10.67): - The scale of the latent dimension. This is used to determine the size of the latent space. + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and width=int(24*10.67)=256 in order + to match the training conditions. """ def __init__( @@ -89,6 +91,9 @@ def __init__( self.register_to_config(latent_dim_scale=latent_dim_scale) def prepare_latents(self, shape, dtype, device, generator, latents): + """ + Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + """ if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -98,23 +103,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents): return latents - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, text_encoder, - decoder, and vqgan have their state dicts saved to CPU and then are moved to a `torch.device('meta') and - loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.text_encoder, self.decoder, self.vqgan]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -144,7 +132,17 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + """ + Copied and adjusted from diffusers.pipelines.kandinsky.pipeline_kandinsky._encode_prompt + """ batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( @@ -221,41 +219,45 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return text_encoder_hidden_states def check_inputs( - self, predicted_image_embeddings, text_encoder_hidden_states, do_classifier_free_guidance, device + self, + image_embeddings, + prompt, + num_inference_steps, + do_classifier_free_guidance, + device, + dtype, ): - if not isinstance(text_encoder_hidden_states, torch.Tensor): - raise TypeError( - f"'text_encoder_hidden_states' must be of type 'torch.Tensor', but got {type(predicted_image_embeddings)}." - ) - if isinstance(predicted_image_embeddings, np.ndarray): - predicted_image_embeddings = torch.Tensor(predicted_image_embeddings, device=device).to( - dtype=text_encoder_hidden_states.dtype + if not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError( + f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}." + ) + if isinstance(image_embeddings, np.ndarray): + image_embeddings = torch.Tensor(image_embeddings, device=device).to( + dtype=dtype ) - if not isinstance(predicted_image_embeddings, torch.Tensor): + if not isinstance(image_embeddings, torch.Tensor): raise TypeError( - f"'predicted_image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(predicted_image_embeddings)}." + f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." ) + + if isinstance(num_inference_steps, int): + num_inference_steps = {0.0: num_inference_steps} - if do_classifier_free_guidance: - if predicted_image_embeddings.size(0) != text_encoder_hidden_states.size(0) // 2: - raise ValueError( - f"'text_encoder_hidden_states' must be double the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." - ) - - # if predicted_image_embeddings.size(0) * 2 == text_encoder_hidden_states.size(0): - text_encoder_hidden_states = text_encoder_hidden_states.chunk(2)[0] - - if predicted_image_embeddings.size(0) != text_encoder_hidden_states.size(0): - raise ValueError( - f"'text_encoder_hidden_states' must be the size of 'predicted_image_embeddings' in the first dimension, but {predicted_image_embeddings.size(0)} != {text_encoder_hidden_states.size(0)}." - ) + if not isinstance(num_inference_steps, dict): + raise TypeError( + f"'num_inference_steps' must be of type 'int' or 'dict', but got {type(num_inference_steps)}." + ) - return predicted_image_embeddings, text_encoder_hidden_states + return image_embeddings, prompt, num_inference_steps @torch.no_grad() + # @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]], prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: Union[Dict[float, int], int] = 12, @@ -266,44 +268,43 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, ): + + # 0. Define commonly used variables device = self._execution_device + dtype = self.decoder.dtype do_classifier_free_guidance = guidance_scale > 1.0 - if isinstance(image_embeds, list): - image_embeds = torch.cat(image_embeds, dim=0) - # image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - if isinstance(num_inference_steps, int): - num_inference_steps = {0.0: num_inference_steps} - - if isinstance(prompt, str): - prompt = [prompt] - elif not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + # 1. Check inputs. Raise error if not correct + image_embeddings, prompt, num_inference_steps = self.check_inputs( + image_embeddings, prompt, num_inference_steps, do_classifier_free_guidance, device, dtype + ) + # 2. Encode caption text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - predicted_image_embeddings, text_encoder_hidden_states = self.check_inputs( - image_embeds, text_encoder_hidden_states, do_classifier_free_guidance, device - ) - dtype = self.decoder.dtype - latent_height = int(predicted_image_embeddings.size(2) * self.config.latent_dim_scale) - latent_width = int(predicted_image_embeddings.size(3) * self.config.latent_dim_scale) - latent_features_shape = (predicted_image_embeddings.size(0), 4, latent_height, latent_width) + # 3. Determine latent shape of latents + latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) + latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) + latent_features_shape = (image_embeddings.size(0), 4, latent_height, latent_width) + # 4. Prepare and set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + # 5. Prepare latents latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents) + + # 6. Run denoising loop for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) effnet = ( - torch.cat([predicted_image_embeddings, torch.zeros_like(predicted_image_embeddings)]) + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) if do_classifier_free_guidance - else predicted_image_embeddings + else image_embeddings ) + # 7. Denoise latents predicted_latents = self.decoder( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, @@ -311,10 +312,12 @@ def __call__( clip=text_encoder_hidden_states, ) + # 8. Check for classifier free guidance and apply it if do_classifier_free_guidance: predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale) + # 9. Renoise latents to next timestep latents = self.scheduler.step( model_output=predicted_latents, timestep=ratio, @@ -322,10 +325,10 @@ def __call__( generator=generator, ).prev_sample - # scale and decode the image latents with vq-vae + # 10. Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) - + if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 370ffdd4b66f..8aa9856dd7e4 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -146,13 +146,13 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 12, decoder_guidance_scale: float = 4.0, num_images_per_prompt: int = 1, height: int = 512, width: int = 512, prior_guidance_scale: float = 4.0, prior_num_inference_steps: Union[int, Dict[float, int]] = {2 / 3: 20, 0.0: 10}, + decoder_num_inference_steps: Union[int, Dict[float, int]] = {0.0: 12}, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -171,9 +171,6 @@ def __call__( if `decoder_guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - num_inference_steps (`int`, *optional*, defaults to 12): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): @@ -186,9 +183,14 @@ def __call__( usually at the expense of lower image quality. prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This pipeline takes an optional dictionary of the form for example - `{2 / 3: 20, 0.0: 10}` 20 steps for the first 1/3 of denoising and 10 steps for the last 2/3 of the - denoising process. + expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. + For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% + to 0% we use 10 denoising steps. + decoder_num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. + For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% + to 0% we use 10 denoising steps. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -218,7 +220,7 @@ def __call__( Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` - [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a + [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ prior_outputs = self.prior_pipe( @@ -244,7 +246,7 @@ def __call__( outputs = self.decoder_pipe( prompt=prompt, image_embeds=image_embeds, - num_inference_steps=num_inference_steps, + num_inference_steps=decoder_num_inference_steps, generator=generator, guidance_scale=decoder_guidance_scale, output_type=output_type, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 8c27e21e0ce5..b48844270209 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ...utils import BaseOutput, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from .modeling_wuerstchen_prior import WuerstchenPrior @@ -83,6 +83,9 @@ def __init__( text_encoder: CLIPTextModel, prior: WuerstchenPrior, scheduler: DDPMWuerstchenScheduler, + latent_mean: float = 42.0, + latent_std: float = 1.0, + resolution_multiple: float = 42.67, ) -> None: super().__init__() self.register_modules( @@ -91,24 +94,7 @@ def __init__( prior=prior, scheduler=scheduler, ) - self.register_to_config() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the text_encoder - and the prior have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to - GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.text_encoder, self.prior]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) + self.register_to_config(latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple) def enable_model_cpu_offload(self, gpu_id=0): r""" @@ -140,6 +126,9 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook def prepare_latents(self, shape, dtype, device, generator, latents): + """ + Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + """ if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -157,6 +146,9 @@ def _encode_prompt( do_classifier_free_guidance, negative_prompt=None, ): + """ + Copied and adjusted from diffusers.pipelines.kandinsky.pipeline_kandinsky._encode_prompt + """ batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( @@ -231,8 +223,35 @@ def _encode_prompt( text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) return text_encoder_hidden_states + + def check_inputs( + self, + prompt, + num_inference_steps, + batch_size, + ): + if not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError( + f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}." + ) + + if isinstance(num_inference_steps, int): + num_inference_steps = {0.0: num_inference_steps} + + if not isinstance(num_inference_steps, dict): + raise TypeError( + f"'num_inference_steps' must be of type 'int' or 'dict', but got {type(num_inference_steps)}." + ) + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + return prompt, num_inference_steps, batch_size @torch.no_grad() + # @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -247,47 +266,53 @@ def __call__( output_type: Optional[str] = "pt", return_dict: bool = True, ): + # 0. Define commonly used variables device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - if isinstance(num_inference_steps, int): - num_inference_steps = {0.0: num_inference_steps} - - if isinstance(prompt, str): - prompt = [prompt] - elif not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + # 1. Check inputs. Raise error if not correct + prompt, num_inference_steps, batch_size = self.check_inputs( + prompt, num_inference_steps, batch_size + ) - batch_size = len(prompt) if isinstance(prompt, list) else 1 + # 2. Encode caption text_encoder_hidden_states = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype - latent_height = ceil(height / 42.67) - latent_width = ceil(width / 42.67) + latent_height = ceil(height / self.resolution_multiple) + latent_width = ceil(width / self.resolution_multiple) num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) + # 4. Prepare and set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + # 5. Prepare latents latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents) + # 6. Run denoising loop for t in self.progress_bar(timesteps[:-1]): ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise image embeddings predicted_image_embedding = self.prior( torch.cat([latents] * 2) if do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, c=text_encoder_hidden_states, ) + # 8. Check for classifier free guidance and apply it if do_classifier_free_guidance: predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) predicted_image_embedding = torch.lerp( predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale ) + # 9. Renoise latents to next timestep latents = self.scheduler.step( model_output=predicted_image_embedding, timestep=ratio, @@ -295,8 +320,8 @@ def __call__( generator=generator, ).prev_sample - # normalize the latents - latents = latents * 42.0 - 1.0 + # 10. Denormalize the latents + latents = latents * self.latent_mean - self.latent_std if output_type == "np": latents = latents.cpu().numpy() From 35772f174053ea54ebfd1f0e917dfd5fd2dfb5cd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 5 Sep 2023 21:08:52 +0200 Subject: [PATCH 156/181] fix copied from usage and add back batch_size --- .../wuerstchen/pipeline_wuerstchen.py | 66 ++++++++++--------- .../pipeline_wuerstchen_combined.py | 8 +-- .../wuerstchen/pipeline_wuerstchen_prior.py | 61 ++++++++++------- .../schedulers/scheduling_ddpm_wuerstchen.py | 3 + tests/pipelines/test_pipelines_common.py | 2 + .../wuerstchen/test_wuerstchen_decoder.py | 12 ++-- 6 files changed, 87 insertions(+), 65 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 622ed840dc18..ec8e6cffb7d3 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -90,10 +90,8 @@ def __init__( ) self.register_to_config(latent_dim_scale=latent_dim_scale) - def prepare_latents(self, shape, dtype, device, generator, latents): - """ - Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents - """ + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -101,6 +99,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents): raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) + latents = latents * scheduler.init_noise_sigma return latents def enable_model_cpu_offload(self, gpu_id=0): @@ -133,16 +132,13 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - ): - """ - Copied and adjusted from diffusers.pipelines.kandinsky.pipeline_kandinsky._encode_prompt - """ + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( @@ -219,30 +215,28 @@ def _encode_prompt( return text_encoder_hidden_states def check_inputs( - self, - image_embeddings, - prompt, - num_inference_steps, - do_classifier_free_guidance, - device, + self, + image_embeddings, + prompt, + num_inference_steps, + do_classifier_free_guidance, + device, dtype, ): if not isinstance(prompt, list): if isinstance(prompt, str): prompt = [prompt] else: - raise TypeError( - f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}." - ) + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) if isinstance(image_embeddings, np.ndarray): - image_embeddings = torch.Tensor(image_embeddings, device=device).to( - dtype=dtype - ) + image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype) if not isinstance(image_embeddings, torch.Tensor): raise TypeError( f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." ) - + if isinstance(num_inference_steps, int): num_inference_steps = {0.0: num_inference_steps} @@ -254,7 +248,7 @@ def check_inputs( return image_embeddings, prompt, num_inference_steps @torch.no_grad() - # @replace_example_docstring(EXAMPLE_DOC_STRING) + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]], @@ -268,7 +262,17 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, ): - + r""" + Function invoked when calling the pipeline for generation. + + Args: + + Examples: + + Returns: + + """ + # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype @@ -294,7 +298,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latents - latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents) + latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) # 6. Run denoising loop for t in self.progress_bar(timesteps[:-1]): @@ -328,7 +332,7 @@ def __call__( # 10. Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) - + if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}") diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 8aa9856dd7e4..f0298a60e016 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -183,13 +183,13 @@ def __call__( usually at the expense of lower image quality. prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. - For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% + expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. + For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% to 0% we use 10 denoising steps. decoder_num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. - For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% + expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. + For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% to 0% we use 10 denoising steps. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index b48844270209..e53661b41683 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -21,7 +21,14 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from .modeling_wuerstchen_prior import WuerstchenPrior @@ -94,7 +101,9 @@ def __init__( prior=prior, scheduler=scheduler, ) - self.register_to_config(latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple) + self.register_to_config( + latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple + ) def enable_model_cpu_offload(self, gpu_id=0): r""" @@ -125,10 +134,8 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook - def prepare_latents(self, shape, dtype, device, generator, latents): - """ - Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents - """ + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -136,6 +143,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents): raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) + latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt( @@ -146,9 +154,6 @@ def _encode_prompt( do_classifier_free_guidance, negative_prompt=None, ): - """ - Copied and adjusted from diffusers.pipelines.kandinsky.pipeline_kandinsky._encode_prompt - """ batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( @@ -223,21 +228,19 @@ def _encode_prompt( text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) return text_encoder_hidden_states - + def check_inputs( - self, - prompt, - num_inference_steps, + self, + prompt, + num_inference_steps, batch_size, ): if not isinstance(prompt, list): if isinstance(prompt, str): prompt = [prompt] else: - raise TypeError( - f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}." - ) - + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + if isinstance(num_inference_steps, int): num_inference_steps = {0.0: num_inference_steps} @@ -245,13 +248,13 @@ def check_inputs( raise TypeError( f"'num_inference_steps' must be of type 'int' or 'dict', but got {type(num_inference_steps)}." ) - + batch_size = len(prompt) if isinstance(prompt, list) else 1 return prompt, num_inference_steps, batch_size @torch.no_grad() - # @replace_example_docstring(EXAMPLE_DOC_STRING) + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -266,14 +269,24 @@ def __call__( output_type: Optional[str] = "pt", return_dict: bool = True, ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + + Examples: + + Returns: + + """ + # 0. Define commonly used variables device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 + batch_size = len(prompt) if isinstance(prompt, list) else 1 - # 1. Check inputs. Raise error if not correct - prompt, num_inference_steps, batch_size = self.check_inputs( - prompt, num_inference_steps, batch_size - ) + # 1. Check inputs. Raise error if not correct + prompt, num_inference_steps, batch_size = self.check_inputs(prompt, num_inference_steps, batch_size) # 2. Encode caption text_encoder_hidden_states = self._encode_prompt( @@ -292,7 +305,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latents - latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents) + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) # 6. Run denoising loop for t in self.progress_bar(timesteps[:-1]): diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 4bd4d51b0c66..ae1365c18a70 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -111,6 +111,9 @@ def __init__( self.s = torch.tensor([s]) self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + def _alpha_cumprod(self, t, device): if self.scaler > 1: t = 1 - (1 - t) ** self.scaler diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 319dcb5aab32..5d520beafebc 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -348,6 +348,8 @@ def test_pipeline_call_signature(self): if v.default != inspect._empty: optional_parameters.add(k) + breakpoint() + parameters = set(parameters.keys()) parameters.remove("self") parameters.discard("kwargs") # kwargs can be added if arguments of pipeline call function are deprecated diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index bfa8f5b7d7f0..71443ab57b57 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -33,10 +33,10 @@ class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = WuerstchenDecoderPipeline params = ["prompt"] - batch_params = ["image_embeds", "prompt", "negative_prompt"] + batch_params = ["image_embeddings", "prompt", "negative_prompt"] required_optional_params = [ "num_images_per_prompt", - "generator", + "decoder", "num_inference_steps", "latents", "negative_prompt", @@ -96,7 +96,7 @@ def dummy_vqgan(self): return model.eval() @property - def dummy_generator(self): + def dummy_decoder(self): torch.manual_seed(0) model_kwargs = { @@ -113,7 +113,7 @@ def dummy_generator(self): return model.eval() def get_dummy_components(self): - generator = self.dummy_generator + decoder = self.dummy_decoder text_encoder = self.dummy_text_encoder tokenizer = self.dummy_tokenizer vqgan = self.dummy_vqgan @@ -121,7 +121,7 @@ def get_dummy_components(self): scheduler = DDPMWuerstchenScheduler() components = { - "generator": generator, + "decoder": decoder, "vqgan": vqgan, "text_encoder": text_encoder, "tokenizer": tokenizer, @@ -136,7 +136,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "image_embeds": torch.ones((1, 16, 10, 10), device=device), + "image_embeddings": torch.ones((1, 16, 10, 10), device=device), "prompt": "horse", "generator": generator, "guidance_scale": 1.0, From 21068df4faa41614718b9f91369d1096361be1da Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 5 Sep 2023 21:13:07 +0200 Subject: [PATCH 157/181] fix API --- scripts/convert_wuerstchen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 6baa1864e57f..c03fef646969 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -55,8 +55,8 @@ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights else: state_dict[key] = orig_state_dict[key] -generator = WuerstchenDiffNeXt() -generator.load_state_dict(state_dict) +deocder = WuerstchenDiffNeXt() +deocder.load_state_dict(state_dict) # Prior orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"] @@ -94,7 +94,7 @@ prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") decoder_pipeline = WuerstchenDecoderPipeline( - text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, generator=generator, scheduler=scheduler + text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler ) decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") From 7f39e0c0394f1088ec294950449aa7ac4cba1016 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 5 Sep 2023 22:36:36 +0200 Subject: [PATCH 158/181] fix arguments --- scripts/convert_wuerstchen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index c03fef646969..49cb670fb107 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -103,13 +103,13 @@ # Decoder text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, - generator=generator, + decoder=deocder, scheduler=scheduler, vqgan=vqmodel, # Prior prior_tokenizer=tokenizer, prior_text_encoder=text_encoder, - prior_prior=prior_model, + prior=prior_model, prior_scheduler=scheduler, ) wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenPipeline") From 0b97829419e21a5f12b1edde58cafce53aa2fef8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 5 Sep 2023 22:38:11 +0200 Subject: [PATCH 159/181] fix combined test --- tests/pipelines/wuerstchen/test_wuerstchen_combined.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 59b0d9c2773e..c34666063101 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -56,7 +56,7 @@ def text_embedder_hidden_size(self): return 32 @property - def dummy_prior_prior(self): + def dummy_prior(self): torch.manual_seed(0) model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2} @@ -130,7 +130,7 @@ def dummy_generator(self): return model.eval() def get_dummy_components(self): - prior_prior = self.dummy_prior_prior + prior = self.dummy_prior prior_text_encoder = self.dummy_prior_text_encoder scheduler = DDPMWuerstchenScheduler() @@ -146,7 +146,7 @@ def get_dummy_components(self): "generator": generator, "vqgan": vqgan, "scheduler": scheduler, - "prior_prior": prior_prior, + "prior": prior, "prior_text_encoder": prior_text_encoder, "prior_tokenizer": tokenizer, "prior_scheduler": scheduler, From b801a56aeb32c5f867e221609728a08f792e3cbd Mon Sep 17 00:00:00 2001 From: Dominic Rampas Date: Tue, 5 Sep 2023 20:34:12 -0400 Subject: [PATCH 160/181] Added timesteps argument + fixes --- docs/source/en/api/pipelines/wuerstchen.md | 4 +- .../wuerstchen/pipeline_wuerstchen.py | 89 ++++++++++++++---- .../pipeline_wuerstchen_combined.py | 18 ++-- .../wuerstchen/pipeline_wuerstchen_prior.py | 90 +++++++++++++++---- .../schedulers/scheduling_ddpm_wuerstchen.py | 19 ++-- 5 files changed, 163 insertions(+), 57 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 8a272a77fe8a..a93de4da1116 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -43,7 +43,7 @@ pipeline = WuerstchenPipeline.from_pretrained( "warp-diffusion/WuerstchenPipeline", torch_dtype=dtype ).to(device) -caption = "A captivating artwork of a mysterious stone golem" +caption = "Anthropomorphic cat dressed as a fire fighter" negative_prompt = "" output = pipeline( @@ -87,7 +87,7 @@ prior_output = prior_pipeline( num_images_per_prompt=num_images_per_prompt, ) decoder_output = decoder_pipeline( - predicted_image_embeddings=prior_output.image_embeds, + image_embeddings=prior_output.image_embeddings, prompt=caption, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index ec8e6cffb7d3..3632dbfe4ef4 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch @@ -218,6 +218,7 @@ def check_inputs( self, image_embeddings, prompt, + negative_prompt, num_inference_steps, do_classifier_free_guidance, device, @@ -228,6 +229,14 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if do_classifier_free_guidance: + if not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") + if isinstance(image_embeddings, list): image_embeddings = torch.cat(image_embeddings, dim=0) if isinstance(image_embeddings, np.ndarray): @@ -237,15 +246,11 @@ def check_inputs( f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." ) - if isinstance(num_inference_steps, int): - num_inference_steps = {0.0: num_inference_steps} - - if not isinstance(num_inference_steps, dict): - raise TypeError( - f"'num_inference_steps' must be of type 'int' or 'dict', but got {type(num_inference_steps)}." - ) + if not isinstance(num_inference_steps, int): + raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") - return image_embeddings, prompt, num_inference_steps + return image_embeddings, prompt, negative_prompt, num_inference_steps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -253,34 +258,79 @@ def __call__( self, image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]], prompt: Union[str, List[str]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: Union[Dict[float, int], int] = 12, + num_inference_steps: int = 12, + timesteps: Optional[List[float]] = None, guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, return_dict: bool = True, ): - r""" + """ Function invoked when calling the pipeline for generation. Args: + image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: - + [`~pipelines.ImagePipelineOutput`] or `tuple` + [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. """ + # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype do_classifier_free_guidance = guidance_scale > 1.0 # 1. Check inputs. Raise error if not correct - image_embeddings, prompt, num_inference_steps = self.check_inputs( - image_embeddings, prompt, num_inference_steps, do_classifier_free_guidance, device, dtype + image_embeddings, prompt, negative_prompt, num_inference_steps = self.check_inputs( + image_embeddings, prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, device, dtype ) # 2. Encode caption @@ -294,8 +344,13 @@ def __call__( latent_features_shape = (image_embeddings.size(0), 4, latent_height, latent_width) # 4. Prepare and set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # 5. Prepare latents latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index f0298a60e016..37a5a6a36bab 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -151,8 +151,10 @@ def __call__( height: int = 512, width: int = 512, prior_guidance_scale: float = 4.0, - prior_num_inference_steps: Union[int, Dict[float, int]] = {2 / 3: 20, 0.0: 10}, - decoder_num_inference_steps: Union[int, Dict[float, int]] = {0.0: 12}, + prior_num_inference_steps: int = 60, + decoder_num_inference_steps: int = 12, + prior_timesteps: Optional[List[float]] = None, + decoder_timesteps: Optional[List[float]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -230,25 +232,23 @@ def __call__( height=height, num_images_per_prompt=num_images_per_prompt, num_inference_steps=prior_num_inference_steps, + timesteps=prior_timesteps, generator=generator, latents=latents, guidance_scale=prior_guidance_scale, output_type="pt", return_dict=False, ) - image_embeds = prior_outputs[0] - - prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt - - if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: - prompt = (image_embeds.shape[0] // len(prompt)) * prompt + image_embeddings = prior_outputs[0] outputs = self.decoder_pipe( prompt=prompt, - image_embeds=image_embeds, + image_embeddings=image_embeddings, num_inference_steps=decoder_num_inference_steps, + timesteps=decoder_timesteps, generator=generator, guidance_scale=decoder_guidance_scale, + num_images_per_prompt=num_images_per_prompt, output_type=output_type, return_dict=return_dict, ) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index e53661b41683..cf9de2f9870f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from math import ceil -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch @@ -57,12 +57,12 @@ class WuerstchenPriorPipelineOutput(BaseOutput): Output class for WuerstchenPriorPipeline. Args: - image_embeds (`torch.FloatTensor` or `np.ndarray`) + image_embeddings (`torch.FloatTensor` or `np.ndarray`) Prior image embeddings for text prompt """ - image_embeds: Union[torch.FloatTensor, np.ndarray] + image_embeddings: Union[torch.FloatTensor, np.ndarray] class WuerstchenPriorPipeline(DiffusionPipeline): @@ -232,7 +232,9 @@ def _encode_prompt( def check_inputs( self, prompt, + negative_prompt, num_inference_steps, + do_classifier_free_guidance, batch_size, ): if not isinstance(prompt, list): @@ -240,18 +242,21 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if do_classifier_free_guidance: + if not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") - if isinstance(num_inference_steps, int): - num_inference_steps = {0.0: num_inference_steps} - - if not isinstance(num_inference_steps, dict): - raise TypeError( - f"'num_inference_steps' must be of type 'int' or 'dict', but got {type(num_inference_steps)}." - ) + if not isinstance(num_inference_steps, int): + raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") batch_size = len(prompt) if isinstance(prompt, list) else 1 - return prompt, num_inference_steps, batch_size + return prompt, negative_prompt, num_inference_steps, batch_size @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -260,24 +265,70 @@ def __call__( prompt: Union[str, List[str]] = None, height: int = 1024, width: int = 1024, - num_inference_steps: Union[int, Dict[float, int]] = {2 / 3: 20, 0.0: 10}, + num_inference_steps: int = 30, + timesteps: List[float] = None, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pt", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, return_dict: bool = True, ): - r""" + """ Function invoked when calling the pipeline for generation. Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: - + [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` + [`~pipelines.WuerstchenPriorPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. """ # 0. Define commonly used variables @@ -286,7 +337,7 @@ def __call__( batch_size = len(prompt) if isinstance(prompt, list) else 1 # 1. Check inputs. Raise error if not correct - prompt, num_inference_steps, batch_size = self.check_inputs(prompt, num_inference_steps, batch_size) + prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size) # 2. Encode caption text_encoder_hidden_states = self._encode_prompt( @@ -301,8 +352,13 @@ def __call__( effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) # 4. Prepare and set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # 5. Prepare latents latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index ae1365c18a70..a92c5d61ff0d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, List import torch @@ -140,7 +140,8 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = def set_timesteps( self, - num_inference_steps: Dict[float, int], + num_inference_steps: int = None, + timesteps: Optional[List[int]] = None, device: Union[str, torch.device] = None, ): """ @@ -153,16 +154,10 @@ def set_timesteps( device (`str` or `torch.device`, optional): the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10} """ - timesteps = None - t_start = 1.0 - for t_end, steps in num_inference_steps.items(): - steps = torch.linspace(t_start, t_end, steps + 1, device=device) - t_start = t_end - if timesteps is None: - timesteps = steps - else: - timesteps = torch.cat([timesteps, steps[1:]]) - + if timesteps is None: + timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device) + if not isinstance(timesteps, torch.Tensor): + timesteps = torch.Tensor(timesteps).to(device) self.timesteps = timesteps def step( From 1e9336c70b70305793878b0532f7e57c01f33f17 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 11:26:15 +0200 Subject: [PATCH 161/181] Update tests/pipelines/test_pipelines_common.py Co-authored-by: Patrick von Platen --- tests/pipelines/test_pipelines_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 5d520beafebc..319dcb5aab32 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -348,8 +348,6 @@ def test_pipeline_call_signature(self): if v.default != inspect._empty: optional_parameters.add(k) - breakpoint() - parameters = set(parameters.keys()) parameters.remove("self") parameters.discard("kwargs") # kwargs can be added if arguments of pipeline call function are deprecated From 9ca78d9d5bad4f60c82a33254f4854b4cf7e7df1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 11:26:21 +0200 Subject: [PATCH 162/181] Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py Co-authored-by: Patrick von Platen --- tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index ac70cc0bea98..672b5f54d9ca 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -189,6 +189,6 @@ def test_attention_slicing_forward_pass(self): test_mean_pixel_difference=test_mean_pixel_difference, ) - @unittest.skip(reason="flakey and float16 requires CUDA") + @unittest.skip(reason="flaky for now") def test_float16_inference(self): super().test_float16_inference() From 9d8ea0745bc398a5479979df600495434a490e32 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 11:30:24 +0200 Subject: [PATCH 163/181] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 37a5a6a36bab..4a20cf56be0d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -193,6 +193,12 @@ def __call__( expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% to 0% we use 10 denoising steps. + prior_timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced `prior_num_inference_steps` + timesteps are used. Must be in descending order. + decoder_timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced `decoder_num_inference_steps` + timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen From a09b4efdbb7a58cf30edffeda7e014aa8a457c43 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 11:34:36 +0200 Subject: [PATCH 164/181] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 4a20cf56be0d..321c739a3e07 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -190,7 +190,7 @@ def __call__( to 0% we use 10 denoising steps. decoder_num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. + expense of slower inference. For more specific timestep spacing, you can pass customized `decoder_timesteps` For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% to 0% we use 10 denoising steps. prior_timesteps (`List[float]`, *optional*): From 1c568ce62826634f6514a9fef9ff1627087f3089 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 11:34:44 +0200 Subject: [PATCH 165/181] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 321c739a3e07..68ab526e058f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -185,9 +185,7 @@ def __call__( usually at the expense of lower image quality. prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This pipeline takes an optional dictionary of the form {end_1: steps_1, end_2: steps_2, ..., end_n: steps_n}. - For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% - to 0% we use 10 denoising steps. + expense of slower inference. For more specific timestep spacing, you can pass customized `prior_timesteps` decoder_num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `decoder_timesteps` From 500cb6e94b402383124d0c6d9a2fae0fd6d0866d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 11:46:06 +0200 Subject: [PATCH 166/181] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py --- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 68ab526e058f..cb9413085a6c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -189,8 +189,6 @@ def __call__( decoder_num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `decoder_timesteps` - For example `{2 / 3: 20, 0.0: 10}` means from 100% noise to 66.6% noise we use 20 denoising steps and from 66.6% - to 0% we use 10 denoising steps. prior_timesteps (`List[float]`, *optional*): Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced `prior_num_inference_steps` timesteps are used. Must be in descending order. From 2d222edf1f93a2bea2e70acb06ce7045763fad90 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 09:50:52 +0000 Subject: [PATCH 167/181] up --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 9 ++++----- .../wuerstchen/pipeline_wuerstchen_prior.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 3632dbfe4ef4..3f9a4f92f84f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -131,7 +131,7 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook - def _encode_prompt( + def encode_prompt( self, prompt, device, @@ -210,9 +210,7 @@ def _encode_prompt( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) - - return text_encoder_hidden_states + return text_encoder_hidden_states, uncond_text_encoder_hidden_states def check_inputs( self, @@ -334,9 +332,10 @@ def __call__( ) # 2. Encode caption - text_encoder_hidden_states = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) # 3. Determine latent shape of latents latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index cf9de2f9870f..3a0961bbd95a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -222,12 +222,7 @@ def _encode_prompt( ) # done duplicates - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_encoder_hidden_states = torch.cat([text_encoder_hidden_states, uncond_text_encoder_hidden_states]) - - return text_encoder_hidden_states + return text_encoder_hidden_states, uncond_text_encoder_hidden_states def check_inputs( self, @@ -340,10 +335,16 @@ def __call__( prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size) # 2. Encode caption - text_encoder_hidden_states = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) + + # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype latent_height = ceil(height / self.resolution_multiple) From d8b62f288f874311528f36379aee91b25ededd0c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 09:52:43 +0000 Subject: [PATCH 168/181] Fix more --- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 12 ++---------- .../wuerstchen/pipeline_wuerstchen_prior.py | 12 ++---------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 3f9a4f92f84f..7cc8a8983d96 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -264,8 +264,6 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, return_dict: bool = True, ): """ @@ -280,8 +278,8 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -303,12 +301,6 @@ def __call__( output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 3a0961bbd95a..8ed4fdceb47c 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -268,8 +268,6 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pt", - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, return_dict: bool = True, ): """ @@ -286,8 +284,8 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -309,12 +307,6 @@ def __call__( output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. From d4751ab7d3bb86721d5782fe6767d6ac2f5ece57 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 12:05:34 +0200 Subject: [PATCH 169/181] failing tests --- .../wuerstchen/pipeline_wuerstchen.py | 17 ++++++++------- .../pipeline_wuerstchen_combined.py | 6 +++--- .../wuerstchen/pipeline_wuerstchen_prior.py | 21 ++++++++++++------- .../schedulers/scheduling_ddpm_wuerstchen.py | 2 +- .../wuerstchen/test_wuerstchen_combined.py | 19 ++++++++--------- .../wuerstchen/test_wuerstchen_decoder.py | 1 - 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 7cc8a8983d96..f96c75dc927f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -227,13 +227,15 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") - + if do_classifier_free_guidance: if not isinstance(negative_prompt, list): if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] else: - raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) if isinstance(image_embeddings, list): image_embeddings = torch.cat(image_embeddings, dim=0) @@ -245,8 +247,10 @@ def check_inputs( ) if not isinstance(num_inference_steps, int): - raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ - In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) return image_embeddings, prompt, negative_prompt, num_inference_steps @@ -276,7 +280,7 @@ def __call__( The prompt or prompts to guide the image generation. num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. + expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. @@ -312,7 +316,6 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. """ - # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index cb9413085a6c..ed417b87f88b 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -189,10 +189,10 @@ def __call__( decoder_num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `decoder_timesteps` - prior_timesteps (`List[float]`, *optional*): + prior_timesteps (`List[float]`, *optional*): Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced `prior_num_inference_steps` timesteps are used. Must be in descending order. - decoder_timesteps (`List[float]`, *optional*): + decoder_timesteps (`List[float]`, *optional*): Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced `decoder_num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 8ed4fdceb47c..f296935a8e5f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from math import ceil -from typing import Callable, Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -237,17 +237,21 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") - + if do_classifier_free_guidance: if not isinstance(negative_prompt, list): if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] else: - raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) if not isinstance(num_inference_steps, int): - raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ - In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) batch_size = len(prompt) if isinstance(prompt, list) else 1 @@ -282,7 +286,7 @@ def __call__( The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. + expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. @@ -324,7 +328,9 @@ def __call__( batch_size = len(prompt) if isinstance(prompt, list) else 1 # 1. Check inputs. Raise error if not correct - prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size) + prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs( + prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size + ) # 2. Encode caption prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -336,7 +342,6 @@ def __call__( # to avoid doing two forward passes text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) - # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype latent_height = ceil(height / self.resolution_multiple) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index a92c5d61ff0d..28311fc03301 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import torch diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index c34666063101..f81c3db8d447 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -39,13 +39,12 @@ class WuerstchenPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "height", "width", "latents", - "guidance_scale", + "decoder_guidance_scale", "negative_prompt", - "num_inference_steps", - "prior_num_inference_steps", + "decoder_num_inference_steps", "return_dict", - "guidance_scale", - "num_images_per_prompt", + "prior_guidance_scale", + "prior_num_inference_steps", "output_type", "return_dict", ] @@ -113,7 +112,7 @@ def dummy_vqgan(self): return model.eval() @property - def dummy_generator(self): + def dummy_decoder(self): torch.manual_seed(0) model_kwargs = { @@ -137,13 +136,13 @@ def get_dummy_components(self): tokenizer = self.dummy_tokenizer text_encoder = self.dummy_text_encoder - generator = self.dummy_generator + decoder = self.dummy_decoder vqgan = self.dummy_vqgan components = { "tokenizer": tokenizer, "text_encoder": text_encoder, - "generator": generator, + "decoder": decoder, "vqgan": vqgan, "scheduler": scheduler, "prior": prior, @@ -162,8 +161,8 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "prompt": "horse", "generator": generator, - "guidance_scale": 4.0, - "num_inference_steps": 2, + "prior_guidance_scale": 4.0, + "decoder_num_inference_steps": 2, "prior_num_inference_steps": 2, "output_type": "np", "height": 128, diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 71443ab57b57..1d843a17f57b 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -36,7 +36,6 @@ class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase) batch_params = ["image_embeddings", "prompt", "negative_prompt"] required_optional_params = [ "num_images_per_prompt", - "decoder", "num_inference_steps", "latents", "negative_prompt", From 3b705e8f4d96d312989eb53ce5690f4091866fa1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:06:23 +0000 Subject: [PATCH 170/181] up --- .../wuerstchen/pipeline_wuerstchen.py | 33 +++++----- .../pipeline_wuerstchen_combined.py | 63 +++++++++---------- .../wuerstchen/pipeline_wuerstchen_prior.py | 37 ++++++----- .../schedulers/scheduling_ddpm_wuerstchen.py | 2 +- 4 files changed, 68 insertions(+), 67 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 7cc8a8983d96..48ca600ca234 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -67,8 +67,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): A scheduler to be used in combination with `prior` to generate image embedding. latent_dim_scale (float, `optional`, defaults to 10.67): Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are - height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and width=int(24*10.67)=256 in order - to match the training conditions. + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. """ def __init__( @@ -227,13 +227,15 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") - + if do_classifier_free_guidance: if not isinstance(negative_prompt, list): if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] else: - raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) if isinstance(image_embeddings, list): image_embeddings = torch.cat(image_embeddings, dim=0) @@ -245,8 +247,10 @@ def check_inputs( ) if not isinstance(num_inference_steps, int): - raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ - In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) return image_embeddings, prompt, negative_prompt, num_inference_steps @@ -276,16 +280,16 @@ def __call__( The prompt or prompts to guide the image generation. num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. + expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `decoder_guidance_scale` is less than `1`). @@ -307,12 +311,11 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple` - [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. """ - # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index cb9413085a6c..ff39cb473f1e 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -79,7 +79,7 @@ def __init__( vqgan: PaellaVQModel, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, - prior: WuerstchenPrior, + prior_prior: WuerstchenPrior, prior_scheduler: DDPMWuerstchenScheduler, ): super().__init__() @@ -90,13 +90,13 @@ def __init__( decoder=decoder, scheduler=scheduler, vqgan=vqgan, - prior=prior, + prior_prior=prior_prior, prior_text_encoder=prior_text_encoder, prior_tokenizer=prior_tokenizer, prior_scheduler=prior_scheduler, ) self.prior_pipe = WuerstchenPriorPipeline( - prior=prior, + prior=prior_prior, text_encoder=prior_text_encoder, tokenizer=prior_tokenizer, scheduler=prior_scheduler, @@ -146,20 +146,18 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, - decoder_guidance_scale: float = 4.0, + guidance_scale: float = 4.0, num_images_per_prompt: int = 1, height: int = 512, width: int = 512, prior_guidance_scale: float = 4.0, prior_num_inference_steps: int = 60, - decoder_num_inference_steps: int = 12, + num_inference_steps: int = 12, prior_timesteps: Optional[List[float]] = None, - decoder_timesteps: Optional[List[float]] = None, + timesteps: Optional[List[float]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, return_dict: bool = True, ): """ @@ -170,7 +168,7 @@ def __call__( The prompt or prompts to guide the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `decoder_guidance_scale` is less than `1`). + if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to 512): @@ -180,25 +178,27 @@ def __call__( prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `prior_guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked + to the text `prompt`, usually at the expense of lower image quality. prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. For more specific timestep spacing, you can pass customized `prior_timesteps` - decoder_num_inference_steps (`int`, *optional*, defaults to 12): + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. For more specific timestep spacing, you can pass customized `decoder_timesteps` - prior_timesteps (`List[float]`, *optional*): - Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced `prior_num_inference_steps` - timesteps are used. Must be in descending order. - decoder_timesteps (`List[float]`, *optional*): - Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced `decoder_num_inference_steps` + expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps` + prior_timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced + `prior_num_inference_steps` timesteps are used. Must be in descending order. + timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced + `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -211,21 +211,14 @@ def __call__( output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple` - [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ prior_outputs = self.prior_pipe( prompt=prompt, @@ -246,10 +239,10 @@ def __call__( outputs = self.decoder_pipe( prompt=prompt, image_embeddings=image_embeddings, - num_inference_steps=decoder_num_inference_steps, - timesteps=decoder_timesteps, + num_inference_steps=num_inference_steps, + timesteps=timesteps, generator=generator, - guidance_scale=decoder_guidance_scale, + guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, output_type=output_type, return_dict=return_dict, diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 8ed4fdceb47c..c65c2d4acd2a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from math import ceil -from typing import Callable, Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -146,7 +146,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - def _encode_prompt( + def encode_prompt( self, prompt, device, @@ -237,17 +237,21 @@ def check_inputs( prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") - + if do_classifier_free_guidance: - if not isinstance(negative_prompt, list): + if negative_prompt is not None and not isinstance(negative_prompt, list): if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] else: - raise TypeError(f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}.") + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) if not isinstance(num_inference_steps, int): - raise TypeError(f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ - In Case you want to provide explicit timesteps, please use the 'timesteps' argument.") + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) batch_size = len(prompt) if isinstance(prompt, list) else 1 @@ -282,16 +286,16 @@ def __call__( The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. + expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `decoder_guidance_scale` is less than `1`). @@ -313,9 +317,9 @@ def __call__( Examples: Returns: - [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` - [`~pipelines.WuerstchenPriorPipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. + [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated image embeddings. """ # 0. Define commonly used variables @@ -324,7 +328,9 @@ def __call__( batch_size = len(prompt) if isinstance(prompt, list) else 1 # 1. Check inputs. Raise error if not correct - prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size) + prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs( + prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size + ) # 2. Encode caption prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -336,7 +342,6 @@ def __call__( # to avoid doing two forward passes text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) - # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype latent_height = ceil(height / self.resolution_multiple) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index a92c5d61ff0d..28311fc03301 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import torch From 84908044f41cd3eef10bb313ab724eec4ac9f1e1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:14:52 +0000 Subject: [PATCH 171/181] up --- docs/source/en/api/pipelines/wuerstchen.md | 10 +++++----- scripts/convert_wuerstchen.py | 6 +++--- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 4 ++-- src/diffusers/pipelines/wuerstchen/__init__.py | 2 +- .../wuerstchen/pipeline_wuerstchen_combined.py | 6 +++--- .../utils/dummy_torch_and_transformers_objects.py | 2 +- tests/pipelines/wuerstchen/test_wuerstchen_combined.py | 6 +++--- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index a93de4da1116..ac4da9f29453 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -29,18 +29,18 @@ A comparison can be seen here: ## Text-to-Image Generation -For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenPipeline` and can be used as follows: +For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows: ```python import torch -from diffusers import WuerstchenPipeline +from diffusers import WuerstchenCombinedPipeline device = "cuda" dtype = torch.float16 num_images_per_prompt = 2 -pipeline = WuerstchenPipeline.from_pretrained( - "warp-diffusion/WuerstchenPipeline", torch_dtype=dtype +pipeline = WuerstchenCombinedPipeline.from_pretrained( + "warp-diffusion/WuerstchenCombinedPipeline", torch_dtype=dtype ).to(device) caption = "Anthropomorphic cat dressed as a fire fighter" @@ -118,7 +118,7 @@ The original codebase, as well as experimental ideas, can be found at [dome272/W ## WuerschenPipeline -[[autodoc]] WuerstchenPipeline +[[autodoc]] WuerstchenCombinedPipeline - all - __call__ diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 49cb670fb107..8961f6ed9136 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -8,7 +8,7 @@ from diffusers import ( DDPMWuerstchenScheduler, WuerstchenDecoderPipeline, - WuerstchenPipeline, + WuerstchenCombinedPipeline, WuerstchenPriorPipeline, ) from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior @@ -99,7 +99,7 @@ decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") # Wuerstchen pipeline -wuerstchen_pipeline = WuerstchenPipeline( +wuerstchen_pipeline = WuerstchenCombinedPipeline( # Decoder text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, @@ -112,4 +112,4 @@ prior=prior_model, prior_scheduler=scheduler, ) -wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenPipeline") +wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 308e0a16b8f2..9aee4cea4f6f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -218,7 +218,7 @@ VideoToVideoSDPipeline, VQDiffusionPipeline, WuerstchenDecoderPipeline, - WuerstchenPipeline, + WuerstchenCombinedPipeline, WuerstchenPriorPipeline, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3bf64940c8d4..50d7188ce8ba 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -132,7 +132,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline, WuerstchenPriorPipeline + from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenCombinedPipeline, WuerstchenPriorPipeline try: diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 269d3cd297d8..3f7ff6fc5511 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,7 +52,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) -from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline +from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenCombinedPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -64,7 +64,7 @@ ("kandinsky22", KandinskyV22CombinedPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), - ("wuerstchen", WuerstchenPipeline), + ("wuerstchen", WuerstchenCombinedPipeline), ] ) diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 998d48689994..a6f6321b048a 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -6,5 +6,5 @@ from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt from .modeling_wuerstchen_prior import WuerstchenPrior from .pipeline_wuerstchen import WuerstchenDecoderPipeline - from .pipeline_wuerstchen_combined import WuerstchenPipeline + from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 5016db08f4f2..71aed32777b8 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -29,9 +29,9 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """ Examples: ```py - >>> from diffusions import WuerstchenPipeline + >>> from diffusions import WuerstchenCombinedPipeline - >>> pipe = WuerstchenPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16).to( + >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16).to( ... "cuda" ... ) >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" @@ -40,7 +40,7 @@ """ -class WuerstchenPipeline(DiffusionPipeline): +class WuerstchenCombinedPipeline(DiffusionPipeline): """ Combined Pipeline for text-to-image generation using Wuerstchen diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3637b3fde439..fd1c897e10e9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1277,7 +1277,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenPipeline(metaclass=DummyObject): +class WuerstchenCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index f81c3db8d447..4fb2af8e4da0 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer -from diffusers import DDPMWuerstchenScheduler, WuerstchenPipeline +from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu @@ -30,8 +30,8 @@ enable_full_determinism() -class WuerstchenPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = WuerstchenPipeline +class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WuerstchenCombinedPipeline params = ["prompt"] batch_params = ["prompt", "negative_prompt"] required_optional_params = [ From 88032f17051e8ed17fe1d7cd83b6f24531a6f271 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:28:47 +0000 Subject: [PATCH 172/181] correct naming --- docs/source/en/api/pipelines/wuerstchen.md | 4 ++-- scripts/convert_wuerstchen.py | 6 +++--- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 14 +------------- .../pipelines/wuerstchen/pipeline_wuerstchen.py | 4 ++-- .../wuerstchen/pipeline_wuerstchen_combined.py | 6 +++--- .../wuerstchen/pipeline_wuerstchen_prior.py | 8 ++++---- 8 files changed, 17 insertions(+), 29 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index ac4da9f29453..cfa16847a4ff 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -69,10 +69,10 @@ dtype = torch.float16 num_images_per_prompt = 2 prior_pipeline = WuerstchenPriorPipeline.from_pretrained( - "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype + "warp-diffusion/wuerstchen-prior", torch_dtype=dtype ).to(device) decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained( - "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype + "warp-diffusion/wuerstchen", torch_dtype=dtype ).to(device) caption = "A captivating artwork of a mysterious stone golem" diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py index 8961f6ed9136..91fd9b79b4ee 100644 --- a/scripts/convert_wuerstchen.py +++ b/scripts/convert_wuerstchen.py @@ -7,8 +7,8 @@ from diffusers import ( DDPMWuerstchenScheduler, - WuerstchenDecoderPipeline, WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior @@ -91,12 +91,12 @@ prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler ) -prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline") +prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior") decoder_pipeline = WuerstchenDecoderPipeline( text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler ) -decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline") +decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen") # Wuerstchen pipeline wuerstchen_pipeline = WuerstchenCombinedPipeline( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9aee4cea4f6f..d72c671671c1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -217,8 +217,8 @@ VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, VQDiffusionPipeline, - WuerstchenDecoderPipeline, WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 50d7188ce8ba..28f42ce9fae9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -132,7 +132,7 @@ VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenCombinedPipeline, WuerstchenPriorPipeline + from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline try: diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 3f7ff6fc5511..13f12e75fb31 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,7 +52,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) -from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenCombinedPipeline +from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -308,8 +308,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - user_agent = kwargs.pop("user_agent", {}) load_config_kwargs = { "cache_dir": cache_dir, @@ -319,8 +317,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): "use_auth_token": use_auth_token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, - "user_agent": user_agent, } config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) @@ -583,8 +579,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - user_agent = kwargs.pop("user_agent", {}) load_config_kwargs = { "cache_dir": cache_dir, @@ -594,8 +588,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): "use_auth_token": use_auth_token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, - "user_agent": user_agent, } config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) @@ -859,8 +851,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - user_agent = kwargs.pop("user_agent", {}) load_config_kwargs = { "cache_dir": cache_dir, @@ -870,8 +860,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): "use_auth_token": use_auth_token, "local_files_only": local_files_only, "revision": revision, - "subfolder": subfolder, - "user_agent": user_agent, } config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 604fcbdd8f72..7f95a8abc50b 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -34,10 +34,10 @@ >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( - ... "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=torch.float16 + ... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16 ... ).to("cuda") >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain( - ... "warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=torch.float16 + ... "warp-diffusion/wuerstchen", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 71aed32777b8..76eca8e67980 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -31,9 +31,9 @@ ```py >>> from diffusions import WuerstchenCombinedPipeline - >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-diffusion/Wuerstchen", torch_dtype=torch.float16).to( - ... "cuda" - ... ) + >>> pipe = WuerstchenCombinedPipeline.from_pretrained( + ... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16 + ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> images = pipe(prompt=prompt) ``` diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index c65c2d4acd2a..d24bf070fe5a 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -42,7 +42,7 @@ >>> from diffusers import WuerstchenPriorPipeline >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( - ... "warp-diffusion/WuerstchenPriorPipeline", torch_dtype=torch.float16 + ... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" @@ -344,8 +344,8 @@ def __call__( # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype - latent_height = ceil(height / self.resolution_multiple) - latent_width = ceil(width / self.resolution_multiple) + latent_height = ceil(height / self.config.resolution_multiple) + latent_width = ceil(width / self.config.resolution_multiple) num_channels = self.prior.config.c_in effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) @@ -388,7 +388,7 @@ def __call__( ).prev_sample # 10. Denormalize the latents - latents = latents * self.latent_mean - self.latent_std + latents = latents * self.config.latent_mean - self.config.latent_std if output_type == "np": latents = latents.cpu().numpy() From fb33746898cf63daa89d76f42043d2dfbe35c0c7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:32:09 +0000 Subject: [PATCH 173/181] correct docs --- docs/source/en/api/pipelines/wuerstchen.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index cfa16847a4ff..4316bc739ca9 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -33,14 +33,14 @@ For the sake of usability Würstchen can be used with a single pipeline. This pi ```python import torch -from diffusers import WuerstchenCombinedPipeline +from diffusers import AutoPipelineForText2Image device = "cuda" dtype = torch.float16 num_images_per_prompt = 2 -pipeline = WuerstchenCombinedPipeline.from_pretrained( - "warp-diffusion/WuerstchenCombinedPipeline", torch_dtype=dtype +pipeline = AutoPipelineForText2Image.from_pretrained( + "warp-diffusion/wuerstchen", torch_dtype=dtype ).to(device) caption = "Anthropomorphic cat dressed as a fire fighter" From 63479571422bec2efea153e210fddb900303e02b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:32:30 +0000 Subject: [PATCH 174/181] correct docs --- src/diffusers/utils/dummy_torch_and_transformers_objects.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index fd1c897e10e9..5a123c1cd1ee 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1262,7 +1262,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenDecoderPipeline(metaclass=DummyObject): +class WuerstchenCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1277,7 +1277,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WuerstchenCombinedPipeline(metaclass=DummyObject): +class WuerstchenDecoderPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 5489081bcd61b617330948fa5b56049332004019 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 6 Sep 2023 12:56:14 +0200 Subject: [PATCH 175/181] fix test params --- tests/pipelines/wuerstchen/test_wuerstchen_combined.py | 10 ++++------ tests/pipelines/wuerstchen/test_wuerstchen_prior.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 4fb2af8e4da0..2abb906801ca 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -39,11 +39,9 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase "height", "width", "latents", - "decoder_guidance_scale", + "guidance_scale", "negative_prompt", - "decoder_num_inference_steps", "return_dict", - "prior_guidance_scale", "prior_num_inference_steps", "output_type", "return_dict", @@ -145,7 +143,7 @@ def get_dummy_components(self): "decoder": decoder, "vqgan": vqgan, "scheduler": scheduler, - "prior": prior, + "prior_prior": prior, "prior_text_encoder": prior_text_encoder, "prior_tokenizer": tokenizer, "prior_scheduler": scheduler, @@ -161,8 +159,8 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "prompt": "horse", "generator": generator, - "prior_guidance_scale": 4.0, - "decoder_num_inference_steps": 2, + "guidance_scale": 4.0, + "num_inference_steps": 2, "prior_num_inference_steps": 2, "output_type": "np", "height": 128, diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py index 672b5f54d9ca..a255a665c48e 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py @@ -140,7 +140,7 @@ def test_wuerstchen_prior(self): pipe.set_progress_bar_config(disable=None) output = pipe(**self.get_dummy_inputs(device)) - image = output.image_embeds + image = output.image_embeddings image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] From 30bc6b6b93ce639449a96252b7c1d50de5c8f7fc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 10:57:36 +0000 Subject: [PATCH 176/181] correct docs --- tests/pipelines/wuerstchen/test_wuerstchen_combined.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 4fb2af8e4da0..ff69442403d0 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -39,9 +39,9 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase "height", "width", "latents", - "decoder_guidance_scale", + "guidance_scale", "negative_prompt", - "decoder_num_inference_steps", + "num_inference_steps", "return_dict", "prior_guidance_scale", "prior_num_inference_steps", @@ -145,7 +145,7 @@ def get_dummy_components(self): "decoder": decoder, "vqgan": vqgan, "scheduler": scheduler, - "prior": prior, + "prior_prior": prior, "prior_text_encoder": prior_text_encoder, "prior_tokenizer": tokenizer, "prior_scheduler": scheduler, @@ -162,7 +162,7 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "horse", "generator": generator, "prior_guidance_scale": 4.0, - "decoder_num_inference_steps": 2, + "num_inference_steps": 2, "prior_num_inference_steps": 2, "output_type": "np", "height": 128, From 09787b19308e8f64573c594778a3b8722fad068e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 14:27:45 +0200 Subject: [PATCH 177/181] fix classifier free guidance --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 7 +++++-- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 7f95a8abc50b..26f217499a24 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -166,6 +166,7 @@ def encode_prompt( text_encoder_hidden_states = text_encoder_output.last_hidden_state text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_text_encoder_hidden_states = None if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: @@ -210,7 +211,7 @@ def encode_prompt( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - return text_encoder_hidden_states, uncond_text_encoder_hidden_states + return text_encoder_hidden_states, uncond_text_encoder_hidden_states def check_inputs( self, @@ -330,7 +331,9 @@ def __call__( prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) # 3. Determine latent shape of latents latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index d24bf070fe5a..8b13d8fdf2b7 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -181,6 +181,7 @@ def encode_prompt( text_encoder_hidden_states = text_encoder_output.last_hidden_state text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_text_encoder_hidden_states = None if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: @@ -340,7 +341,9 @@ def __call__( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) # 3. Determine latent shape of image embeddings dtype = text_encoder_hidden_states.dtype From ed9f96a8ffb4d28ce49b005abed9b36b29afe609 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 14:31:15 +0200 Subject: [PATCH 178/181] fix classifier free guidance --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 26f217499a24..9f0a58e80540 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -382,6 +382,7 @@ def __call__( ).prev_sample # 10. Scale and decode the image latents with vq-vae + import ipdb; ipdb.set_trace() latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) From 30a86b3cdd0a9757d6148a1e99923a9559231583 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 14:34:34 +0200 Subject: [PATCH 179/181] fix more --- src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py | 3 +++ src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py index e33d2e576c01..09bdd16592df 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py @@ -22,6 +22,7 @@ from ...models.modeling_utils import ModelMixin from ...models.vae import DecoderOutput, VectorQuantizer from ...models.vq_model import VQEncoderOutput +from ...utils import apply_forward_hook class MixingResidualBlock(nn.Module): @@ -128,6 +129,7 @@ def __init__( nn.PixelShuffle(up_down_scale_factor), ) + @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: h = self.in_block(x) h = self.down_blocks(h) @@ -137,6 +139,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOut return VQEncoderOutput(latents=h) + @apply_forward_hook def decode( self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 9f0a58e80540..26f217499a24 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -382,7 +382,6 @@ def __call__( ).prev_sample # 10. Scale and decode the image latents with vq-vae - import ipdb; ipdb.set_trace() latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) From 3f04adaf77b38eb92577027ff419af4a652fc219 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 14:39:51 +0200 Subject: [PATCH 180/181] fix all --- src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py | 2 +- .../pipelines/wuerstchen/pipeline_wuerstchen_combined.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 26f217499a24..78aeebed7943 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -338,7 +338,7 @@ def __call__( # 3. Determine latent shape of latents latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) - latent_features_shape = (image_embeddings.size(0), 4, latent_height, latent_width) + latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width) # 4. Prepare and set timesteps if timesteps is not None: diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index 76eca8e67980..ff4c31686bf5 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -242,7 +242,6 @@ def __call__( timesteps=timesteps, generator=generator, guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, output_type=output_type, return_dict=return_dict, ) From c35f3f7cde446745b75c8fb92f39441d0edcebd9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Sep 2023 15:25:33 +0200 Subject: [PATCH 181/181] make tests faster --- tests/pipelines/test_pipelines_common.py | 2 ++ tests/pipelines/wuerstchen/test_wuerstchen_decoder.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 319dcb5aab32..a6f828443cb0 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -535,6 +535,8 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4): def test_components_function(self): init_components = self.get_dummy_components() + init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))} + pipe = self.pipeline_class(**init_components) self.assertTrue(hasattr(pipe, "components")) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py index 1d843a17f57b..709e2c1a3436 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py @@ -125,6 +125,7 @@ def get_dummy_components(self): "text_encoder": text_encoder, "tokenizer": tokenizer, "scheduler": scheduler, + "latent_dim_scale": 4.0, } return components @@ -135,7 +136,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "image_embeddings": torch.ones((1, 16, 10, 10), device=device), + "image_embeddings": torch.ones((1, 4, 4, 4), device=device), "prompt": "horse", "generator": generator, "guidance_scale": 1.0, @@ -162,9 +163,9 @@ def test_wuerstchen_decoder(self): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 424, 424, 3) + assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]) + expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2