From c805393a2aea278f06866865564ece35e8967abe Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Sat, 13 May 2023 15:56:33 +0530 Subject: [PATCH 1/6] Add img2img --- src/diffusers/pipelines/kandinsky/__init__.py | 1 + .../kandinsky/pipeline_kandinsky_img2img.py | 393 ++++++++++++++++++ 2 files changed, 394 insertions(+) create mode 100644 src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py diff --git a/src/diffusers/pipelines/kandinsky/__init__.py b/src/diffusers/pipelines/kandinsky/__init__.py index 93731dbdbf3d..f441c7d1abeb 100644 --- a/src/diffusers/pipelines/kandinsky/__init__.py +++ b/src/diffusers/pipelines/kandinsky/__init__.py @@ -15,5 +15,6 @@ from .pipeline_kandinsky import KandinskyPipeline from .pipeline_kandinsky_prior import KandinskyPriorPipeline from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline + from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline from .text_encoder import MultilingualCLIP from .text_proj import KandinskyTextProjModel diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py new file mode 100644 index 000000000000..c91949c3f9f7 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -0,0 +1,393 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +from transformers import ( + XLMRobertaTokenizerFast, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...schedulers import UnCLIPScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from .text_encoder import MultilingualCLIP +from .text_proj import KandinskyTextProjModel +import PIL + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_new_h_w(h, w): + new_h = h // 64 + if h % 64 != 0: + new_h += 1 + new_w = w // 64 + if w % 64 != 0: + new_w += 1 + return new_h * 8, new_w * 8 + + +class KandinskyImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizerFast`]): + Tokenizer of class + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + text_proj ([`KandinskyTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizerFast, + text_proj: KandinskyTextProjModel, + unet: UNet2DConditionModel, + scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + unet=unet, + scheduler=scheduler, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + latents = latents * scheduler.init_noise_sigma + + shape = latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) + latents = init_latents + + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + strength: float = 0.75, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + image_embeds: Optional[torch.FloatTensor] = None, + negative_image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(device) + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeds, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps_tensor[:1].repeat(batch_size * num_images_per_prompt) + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width) + + # create initial latent + latents = self.prepare_latents( + image, + latent_timestep, + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + self.scheduler, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps_tensor): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + sample=latent_model_input, # [2, 4, 96, 96] + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + ).sample + + # YiYi Notes: CFG is currently implemented exactly as original repo as a baseline, + # i.e. we apply cfg to predicted noise, and take predicted variance as it is (uncond + cond) + # this means the our latent shape is batch_size *2 instad batch_size + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + variance_pred_uncond, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred] * 2) + variance_pred = torch.cat([variance_pred_uncond, variance_pred_text]) + noise_pred = torch.cat([noise_pred, variance_pred], dim=1) + + if i + 1 == timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + prev_timestep=prev_timestep, + generator=generator, + batch_size=batch_size, + ).prev_sample + + _, latents = latents.chunk(2) + + return latents From 25d5700762ea2a13fbb677bd1b56a4c61dd5d338 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Sun, 14 May 2023 18:29:51 +0530 Subject: [PATCH 2/6] Add DDPM scheduler and image encoding/processing --- .../kandinsky/pipeline_kandinsky_img2img.py | 46 +++++++++++++++---- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index c91949c3f9f7..86e7a5e130bf 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -21,7 +21,8 @@ from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline -from ...schedulers import UnCLIPScheduler +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, @@ -31,6 +32,8 @@ from .text_encoder import MultilingualCLIP from .text_proj import KandinskyTextProjModel import PIL +from PIL import Image +import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,6 +47,13 @@ def get_new_h_w(h, w): new_w += 1 return new_h * 8, new_w * 8 +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image class KandinskyImg2ImgPipeline(DiffusionPipeline): """ @@ -71,7 +81,8 @@ def __init__( tokenizer: XLMRobertaTokenizerFast, text_proj: KandinskyTextProjModel, unet: UNet2DConditionModel, - scheduler: UnCLIPScheduler, + scheduler: DDPMScheduler, + movq: VQModel ): super().__init__() @@ -81,6 +92,7 @@ def __init__( text_proj=text_proj, unet=unet, scheduler=scheduler, + movq=movq ) def get_timesteps(self, num_inference_steps, strength, device): @@ -88,7 +100,7 @@ def get_timesteps(self, num_inference_steps, strength, device): init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start @@ -99,15 +111,14 @@ def prepare_latents(self, latents, latent_timestep, shape, dtype, device, genera if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) + latents = latents * scheduler.init_noise_sigma shape = latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) - latents = init_latents - + latents = self.scheduler.add_noise(latents, noise, latent_timestep) return latents def _encode_prompt( @@ -329,6 +340,9 @@ def __call__( text_encoder_hidden_states=text_encoder_hidden_states, ) + image = prepare_image(image, width, height).to(device) + latents = self.movq.encode(image)["latents"] + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps_tensor[:1].repeat(batch_size * num_images_per_prompt) @@ -339,7 +353,7 @@ def __call__( # create initial latent latents = self.prepare_latents( - image, + latents, latent_timestep, (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, @@ -383,11 +397,23 @@ def __call__( noise_pred, t, latents, - prev_timestep=prev_timestep, generator=generator, - batch_size=batch_size, ).prev_sample _, latents = latents.chunk(2) - return latents + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + From 3ec09a958b1093882ec0b2e2e0fc7a00792d2d61 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Sun, 14 May 2023 18:48:25 +0530 Subject: [PATCH 3/6] Fix import --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 2 +- .../dummy_torch_and_transformers_objects.py | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 53e70a96928e..d2f1afd29f08 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -129,6 +129,7 @@ IFPipeline, IFSuperResolutionPipeline, KandinskyInpaintPipeline, + KandinskyImg2ImgPipeline, KandinskyPipeline, KandinskyPriorPipeline, LDMTextToImagePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ba906e22300d..03a4d9c7d371 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -52,7 +52,7 @@ IFPipeline, IFSuperResolutionPipeline, ) - from .kandinsky import KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline + from .kandinsky import KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, KandinskyImg2ImgPipeline from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cdd61dcf2ac5..1f8bd6a3d2cb 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,23 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + +class KandinskyImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + + class KandinskyInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 79c818b369d57512079a18e6a569fac61fe1c489 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Tue, 16 May 2023 12:53:19 +0530 Subject: [PATCH 4/6] Add DDPM to text2img --- src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index 3ff4c1287cd7..ea0c64326e65 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -22,7 +22,7 @@ from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput -from ...schedulers import UnCLIPScheduler +from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, @@ -58,7 +58,7 @@ class KandinskyPipeline(DiffusionPipeline): Frozen text-encoder. tokenizer ([`XLMRobertaTokenizerFast`]): Tokenizer of class - scheduler ([`UnCLIPScheduler`]): + scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. @@ -74,7 +74,7 @@ def __init__( tokenizer: XLMRobertaTokenizerFast, text_proj: KandinskyTextProjModel, unet: UNet2DConditionModel, - scheduler: UnCLIPScheduler, + scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() @@ -376,9 +376,7 @@ def __call__( noise_pred, t, latent_model_input, - prev_timestep=prev_timestep, generator=generator, - batch_size=batch_size, ).prev_sample _, latents = latents.chunk(2) From 05a5b068d3885e126fc73eb837d742785afea403 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Wed, 17 May 2023 17:56:08 +0530 Subject: [PATCH 5/6] Add img2img tests --- .../kandinsky/pipeline_kandinsky_img2img.py | 19 +- .../kandinsky/test_kandinsky_img2img.py | 307 ++++++++++++++++++ 2 files changed, 317 insertions(+), 9 deletions(-) create mode 100644 tests/pipelines/kandinsky/test_kandinsky_img2img.py diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index 86e7a5e130bf..5147b9b6b4d3 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -38,14 +38,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def get_new_h_w(h, w): - new_h = h // 64 - if h % 64 != 0: +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: new_h += 1 - new_w = w // 64 - if w % 64 != 0: + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: new_w += 1 - return new_h * 8, new_w * 8 + return new_h * scale_factor, new_w * scale_factor def prepare_image(pil_image, w=512, h=512): pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) @@ -67,7 +67,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): Frozen text-encoder. tokenizer ([`XLMRobertaTokenizerFast`]): Tokenizer of class - scheduler ([`UnCLIPScheduler`]): + scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. @@ -94,6 +94,7 @@ def __init__( scheduler=scheduler, movq=movq ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -347,9 +348,9 @@ def __call__( timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps_tensor[:1].repeat(batch_size * num_images_per_prompt) - num_channels_latents = self.unet.config.in_channels + num_channels_latents = self.movq.config.latent_channels - height, width = get_new_h_w(height, width) + height, width = get_new_h_w(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py new file mode 100644 index 000000000000..47ac7fa667e9 --- /dev/null +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -0,0 +1,307 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import XLMRobertaTokenizer + +from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline, DDPMScheduler, UNet2DConditionModel, VQModel, KandinskyImg2ImgPipeline +from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP +from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +torch.backends.cuda.matmul.allow_tf32 = False +torch.use_deterministic_algorithms(True) + + +class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = KandinskyImg2ImgPipeline + params = ["prompt", "image_embeds", "negative_image_embeds", "image"] + batch_params = [ + "prompt", + "negative_prompt", + "image_embeds", + "negative_image_embeds", + "image", + ] + required_optional_params = [ + "generator", + "height", + "width", + "strength", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "guidance_scale", + "num_images_per_prompt", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = XLMRobertaTokenizer.from_pretrained("YiYiXu/Kandinsky", subfolder="tokenizer") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = MCLIPConfig( + numDims=self.cross_attention_dim, + transformerDimensions=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=250002, + ) + + text_encoder = MultilingualCLIP(config) + text_encoder = text_encoder.eval() + + return text_encoder + + @property + def dummy_text_proj(self): + torch.manual_seed(0) + + model_kwargs = { + "clip_embeddings_dim": self.cross_attention_dim, + "time_embed_dim": self.time_embed_dim, + "clip_extra_context_tokens": 2, + "cross_attention_dim": self.cross_attention_dim, + "clip_text_encoder_hidden_states_dim": self.text_embedder_hidden_size, + } + + model = KandinskyTextProjModel(**model_kwargs) + return model + + @property + def dummy_unet(self): + torch.manual_seed(0) + + model_kwargs = { + "in_channels": 4, + # Out channels is double in channels because predicts mean and variance + "out_channels": 8, + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": "identity", + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self): + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + unet = self.dummy_unet + text_proj = self.dummy_text_proj + movq = self.dummy_movq + + ddpm_config = { + "clip_sample": True, + "clip_sample_range": 2.0, + "sample_max_value": None, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "variance_type": "learned_range", + "thresholding": True, + "beta_schedule": "linear", + "beta_start": 0.00085, + "beta_end":0.012 + } + + scheduler = DDPMScheduler(**ddpm_config) + + components = { + "text_proj": text_proj, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "movq": movq, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed)).to(device) + negative_image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed + 1)).to(device) + # create init_image + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((256, 256)) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "image": init_image, + "image_embeds": image_embeds, + "negative_image_embeds": negative_image_embeds, + "generator": generator, + "height": 64, + "width": 64, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_kandinsky_img2img(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + print(f"image.shape {image.shape}") + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6635241 , 0.6152489 , 0.5687914 , 0.57371366, 0.53458804, 0.47828954, 0.5454488 , 0.51518494, 0.49540082]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinsky_img2img(self): + expected_image = load_numpy( +"https://user-images.githubusercontent.com/43698245/238191310-a99fe3cf-c2ee-417e-94f1-c1829a0ae0a3.png + ) + + init_image = load_image( + "https://user-images.githubusercontent.com/43698245/238191310-a99fe3cf-c2ee-417e-94f1-c1829a0ae0a3.png" + ) + prompt = "A red cartoon frog, 4k" + + pipe_prior = KandinskyPriorPipeline.from_pretrained("YiYiXu/Kandinsky-prior", torch_dtype=torch.float16) + pipe_prior.to(torch_device) + + pipeline = KandinskyImg2ImgPipeline.from_pretrained("ayushtues/test-kandinsky-img2img", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image_emb = pipe_prior( + prompt, + generator=generator, + ).images + zero_image_emb = pipe_prior("").images + + output = pipeline( + prompt, + image=init_image, + image_embeds=image_emb, + negative_image_embeds=zero_image_emb, + generator=generator, + num_inference_steps=100, + height=768, + width=768, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) From 3788c502fd7dea10fe728921403384c573f808a2 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Wed, 17 May 2023 18:19:55 +0530 Subject: [PATCH 6/6] Add expected image url --- tests/pipelines/kandinsky/test_kandinsky_img2img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py index 47ac7fa667e9..a93e0d3a32ea 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -266,7 +266,7 @@ def tearDown(self): def test_kandinsky_img2img(self): expected_image = load_numpy( -"https://user-images.githubusercontent.com/43698245/238191310-a99fe3cf-c2ee-417e-94f1-c1829a0ae0a3.png + "https://user-images.githubusercontent.com/43698245/238954026-6c3e3da6-ef18-4d78-b521-6386e6922444.png" ) init_image = load_image( @@ -297,6 +297,7 @@ def test_kandinsky_img2img(self): num_inference_steps=100, height=768, width=768, + strength=0.2, output_type="np", )