diff --git a/docs/source/api/pipelines/unclip.mdx b/docs/source/api/pipelines/unclip.mdx index 0d2e17601261..c744a567b1bc 100644 --- a/docs/source/api/pipelines/unclip.mdx +++ b/docs/source/api/pipelines/unclip.mdx @@ -28,4 +28,6 @@ The unCLIP model in diffusers comes from kakaobrain's karlo and the original cod ## UnCLIPPipeline [[autodoc]] pipelines.unclip.pipeline_unclip.UnCLIPPipeline + - __call__ +[[autodoc]] pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline - __call__ \ No newline at end of file diff --git a/scripts/convert_unclip_txt2img_to_image_variation.py b/scripts/convert_unclip_txt2img_to_image_variation.py new file mode 100644 index 000000000000..d228a537ed4c --- /dev/null +++ b/scripts/convert_unclip_txt2img_to_image_variation.py @@ -0,0 +1,40 @@ +import argparse + +from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--txt2img_unclip", + default="kakaobrain/karlo-v1-alpha", + type=str, + required=False, + help="The pretrained txt2img unclip.", + ) + + args = parser.parse_args() + + txt2img = UnCLIPPipeline.from_pretrained(args.txt2img_unclip) + + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + img2img = UnCLIPImageVariationPipeline( + decoder=txt2img.decoder, + text_encoder=txt2img.text_encoder, + tokenizer=txt2img.tokenizer, + text_proj=txt2img.text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=txt2img.super_res_first, + super_res_last=txt2img.super_res_last, + decoder_scheduler=txt2img.decoder_scheduler, + super_res_scheduler=txt2img.super_res_scheduler, + ) + + 
img2img.save_pretrained(args.dump_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 10e7d560b147..7c158a4fc0e1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -105,6 +105,7 @@ StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionUpscalePipeline, + UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7cecfe569234..ced96e60f6f2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -53,7 +53,7 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .unclip import UnCLIPPipeline + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/unclip/__init__.py b/src/diffusers/pipelines/unclip/__init__.py index c495367bc770..23b54a7d2f79 100644 --- a/src/diffusers/pipelines/unclip/__init__.py +++ b/src/diffusers/pipelines/unclip/__init__.py @@ -13,4 +13,5 @@ from ...utils.dummy_torch_and_transformers_objects import UnCLIPPipeline else: from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline from .text_proj import UnCLIPTextProjModel diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 5dc8ed3a89e9..a8e8cba8375d 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -45,6 +45,8 @@ class UnCLIPPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100644 index 000000000000..32b950397a32 --- /dev/null +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,456 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import List, Optional, Union + +import torch +from torch.nn import functional as F + +import PIL +from diffusers import UNet2DConditionModel, UNet2DModel +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import UnCLIPScheduler +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...utils import is_accelerate_available, logging +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DiffusionPipeline): + """ + Pipeline to generate variations from an input image using unCLIP + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. 
+ super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", 
dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + 
uncond_text_mask = uncond_input.attention_mask.bool().to(device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed.
After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPFeatureExtractor`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) used + to make generation deterministic. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the super resolution denoising process. + decoder_guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+ """ + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt) + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + 
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + 
prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ba2798c784ef..25f347be9021 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -199,6 +199,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UnCLIPImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index c1f67e557fd9..fb0cb75ea703 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -281,7 +281,7 @@ def test_unclip_karlo(self): assert image.shape == (256, 256, 3) assert np.abs(expected_image - image).max() < 1e-2 - def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + def 
test_unclip_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py new file mode 100644 index 000000000000..64b51536084b --- /dev/null +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import UnCLIPImageVariationPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel +from diffusers.utils import floats_tensor, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import load_image, require_torch_gpu +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UnCLIPImageVariationPipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config) + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + num_hidden_layers=5, + 
num_attention_heads=4, + image_size=32, + intermediate_size=37, + patch_size=1, + ) + return CLIPVisionModelWithProjection(config) + + @property + def dummy_text_proj(self): + torch.manual_seed(0) + + model_kwargs = { + "clip_embeddings_dim": self.text_embedder_hidden_size, + "time_embed_dim": self.time_embed_dim, + "cross_attention_dim": self.cross_attention_dim, + } + + model = UnCLIPTextProjModel(**model_kwargs) + return model + + @property + def dummy_decoder(self): + torch.manual_seed(0) + + model_kwargs = { + "sample_size": 64, + # RGB in channels + "in_channels": 3, + # Out channels is double in channels because predicts mean and variance + "out_channels": 6, + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": "identity", + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_super_res_kwargs(self): + return { + "sample_size": 128, + "layers_per_block": 1, + "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"), + "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"), + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "in_channels": 6, + "out_channels": 3, + } + + @property + def dummy_super_res_first(self): + torch.manual_seed(0) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + @property + def dummy_super_res_last(self): + # seeded differently to get different unet than `self.dummy_super_res_first` + torch.manual_seed(1) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + def get_pipeline(self, 
device): + decoder = self.dummy_decoder + text_proj = self.dummy_text_proj + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + super_res_first = self.dummy_super_res_first + super_res_last = self.dummy_super_res_last + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) + + image_encoder = self.dummy_image_encoder + + pipe = UnCLIPImageVariationPipeline( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + return pipe + + def get_pipeline_inputs(self, device, seed, pil_image=False): + input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + generator = torch.Generator(device=device).manual_seed(seed) + + if pil_image: + input_image = input_image * 0.5 + 0.5 + input_image = input_image.clamp(0, 1) + input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy() + input_image = DiffusionPipeline.numpy_to_pil(input_image)[0] + + return { + "image": input_image, + "generator": generator, + "decoder_num_inference_steps": 2, + "super_res_num_inference_steps": 2, + "output_type": "np", + } + + def test_unclip_image_variation_input_tensor(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed) + + 
image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9988, + 0.9997, + 0.9944, + 0.0003, + 0.0003, + 0.9974, + 0.0003, + 0.0004, + 0.9931, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_image(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9988, + 0.9997, + 0.9944, + 0.0003, + 0.0003, + 0.9974, + 0.0003, + 0.0004, + 0.9931, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_list_images(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + pipeline_inputs["image"] = [ + pipeline_inputs["image"], + pipeline_inputs["image"], + ] + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + tuple_pipeline_inputs["image"] = [ + tuple_pipeline_inputs["image"], + tuple_pipeline_inputs["image"], + ] + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + 
image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (2, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9997, + 0.9997, + 0.0003, + 0.0003, + 0.9950, + 0.0003, + 0.9993, + 0.9957, + 0.0004, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_num_images_per_prompt(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + pipeline_inputs["image"] = [ + pipeline_inputs["image"], + pipeline_inputs["image"], + ] + + output = pipe(**pipeline_inputs, num_images_per_prompt=2) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + tuple_pipeline_inputs["image"] = [ + tuple_pipeline_inputs["image"], + tuple_pipeline_inputs["image"], + ] + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + num_images_per_prompt=2, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (4, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9997, + 0.9997, + 0.0008, + 0.9952, + 0.9980, + 0.9997, + 0.9961, + 0.9997, + 0.9995, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_unclip_image_variation_karlo(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png" + ) + expected_image = load_numpy( + 
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/unclip/karlo_v1_alpha_cat_variation_fp16.npy" + ) + + pipeline = UnCLIPImageVariationPipeline.from_pretrained( + "fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16 + ) + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline( + input_image, + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image) + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 1e-2