diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 368ea30a2690..bfb9f120b12c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -230,6 +230,8 @@ title: UnCLIP - local: api/pipelines/latent_diffusion_uncond title: Unconditional Latent Diffusion + - local: api/pipelines/unidiffuser + title: UniDiffuser - local: api/pipelines/versatile_diffusion title: Versatile Diffusion - local: api/pipelines/vq_diffusion diff --git a/docs/source/en/api/diffusion_pipeline.mdx b/docs/source/en/api/diffusion_pipeline.mdx index 280802d6a89a..66e5b7b23bbb 100644 --- a/docs/source/en/api/diffusion_pipeline.mdx +++ b/docs/source/en/api/diffusion_pipeline.mdx @@ -45,3 +45,8 @@ By default diffusion pipelines return an object of class By default diffusion pipelines return an object of class [[autodoc]] pipelines.AudioPipelineOutput + +## ImageTextPipelineOutput +By default diffusion pipelines return an object of class + +[[autodoc]] ImageTextPipelineOutput diff --git a/docs/source/en/api/pipelines/unidiffuser.mdx b/docs/source/en/api/pipelines/unidiffuser.mdx new file mode 100644 index 000000000000..10290e263e6d --- /dev/null +++ b/docs/source/en/api/pipelines/unidiffuser.mdx @@ -0,0 +1,204 @@ + + +# UniDiffuser + +The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://arxiv.org/abs/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu. + +The abstract of the [paper](https://arxiv.org/abs/2303.06555) is the following: + +*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).* + +Resources: + +* [Paper](https://arxiv.org/abs/2303.06555). +* [Original Code](https://github.com/thu-ml/unidiffuser). + +Available Checkpoints are: +- *UniDiffuser-v0 (512x512 resolution)* [thu-ml/unidiffuser-v0](https://huggingface.co/thu-ml/unidiffuser-v0) +- *UniDiffuser-v1 (512x512 resolution)* [thu-ml/unidiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1) + +This pipeline was contributed by our community member [dg845](https://github.com/dg845). + +## Available Pipelines: + +| Pipeline | Tasks | Demo | Colab | +|:---:|:---:|:---:|:---:| +| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*,
*Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | [🤗 Spaces](https://huggingface.co/spaces/thu-ml/unidiffuser) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/unidiffuser.ipynb) | + +## Usage Examples + +Because the UniDiffuser model is trained to model the joint distribution of (image, text) pairs, it is capable of performing a diverse range of generation tasks. + +### Unconditional Image and Text Generation + +Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce a (image, text) pair: + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Unconditional image and text generation. The generation task is automatically inferred. +sample = pipe(num_inference_steps=20, guidance_scale=8.0) +image = sample.images[0] +text = sample.text[0] +image.save("unidiffuser_joint_sample_image.png") +print(text) +``` + +This is also called "joint" generation in the UniDiffusers paper, since we are sampling from the joint image-text distribution. + +Note that the generation task is inferred from the inputs used when calling the pipeline. +It is also possible to manually specify the unconditional generation task ("mode") manually with [`UniDiffuserPipeline.set_joint_mode`]: + +```python +# Equivalent to the above. +pipe.set_joint_mode() +sample = pipe(num_inference_steps=20, guidance_scale=8.0) +``` + +When the mode is set manually, subsequent calls to the pipeline will use the set mode without attempting the infer the mode. +You can reset the mode with [`UniDiffuserPipeline.reset_mode`], after which the pipeline will once again infer the mode. + +You can also generate only an image or only text (which the UniDiffuser paper calls "marginal" generation since we sample from the marginal distribution of images and text, respectively): + +```python +# Unlike other generation tasks, image-only and text-only generation don't use classifier-free guidance +# Image-only generation +pipe.set_image_mode() +sample_image = pipe(num_inference_steps=20).images[0] +# Text-only generation +pipe.set_text_mode() +sample_text = pipe(num_inference_steps=20).text[0] +``` + +### Text-to-Image Generation + +UniDiffuser is also capable of sampling from conditional distributions; that is, the distribution of images conditioned on a text prompt or the distribution of texts conditioned on an image. +Here is an example of sampling from the conditional image distribution (text-to-image generation or text-conditioned image generation): + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Text-to-image generation +prompt = "an elephant under the sea" + +sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0) +t2i_image = sample.images[0] +t2i_image.save("unidiffuser_text2img_sample_image.png") +``` + +The `text2img` mode requires that either an input `prompt` or `prompt_embeds` be supplied. You can set the `text2img` mode manually with [`UniDiffuserPipeline.set_text_to_image_mode`]. + +### Image-to-Text Generation + +Similarly, UniDiffuser can also produce text samples given an image (image-to-text or image-conditioned text generation): + +```python +import torch + +from diffusers import UniDiffuserPipeline +from diffusers.utils import load_image + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Image-to-text generation +image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" +init_image = load_image(image_url).resize((512, 512)) + +sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) +i2t_text = sample.text[0] +print(i2t_text) +``` + +The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`]. + +### Image Variation + +The UniDiffuser authors suggest performing image variation through a "round-trip" generation method, where given an input image, we first perform an image-to-text generation, and the perform a text-to-image generation on the outputs of the first generation. +This produces a new image which is semantically similar to the input image: + +```python +import torch + +from diffusers import UniDiffuserPipeline +from diffusers.utils import load_image + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Image variation can be performed with a image-to-text generation followed by a text-to-image generation: +# 1. Image-to-text generation +image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" +init_image = load_image(image_url).resize((512, 512)) + +sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) +i2t_text = sample.text[0] +print(i2t_text) + +# 2. Text-to-image generation +sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0) +final_image = sample.images[0] +final_image.save("unidiffuser_image_variation_sample.png") +``` + +### Text Variation + + +Similarly, text variation can be performed on an input prompt with a text-to-image generation followed by a image-to-text generation: + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Text variation can be performed with a text-to-image generation followed by a image-to-text generation: +# 1. Text-to-image generation +prompt = "an elephant under the sea" + +sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0) +t2i_image = sample.images[0] +t2i_image.save("unidiffuser_text2img_sample_image.png") + +# 2. Image-to-text generation +sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0) +final_prompt = sample.text[0] +print(final_prompt) +``` + +## UniDiffuserPipeline +[[autodoc]] UniDiffuserPipeline + - all + - __call__ diff --git a/scripts/convert_unidiffuser_to_diffusers.py b/scripts/convert_unidiffuser_to_diffusers.py new file mode 100644 index 000000000000..891d289d8c76 --- /dev/null +++ b/scripts/convert_unidiffuser_to_diffusers.py @@ -0,0 +1,776 @@ +# Convert the original UniDiffuser checkpoints into diffusers equivalents. + +import argparse +from argparse import Namespace + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DPMSolverMultistepScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "solver_order": 3, + } +) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +# config.num_head_channels => num_head_channels +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + num_head_channels=1, +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new + checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // num_head_channels // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_vae_diffusers_config(config_type): + # Hardcoded for now + if args.config_type == "test": + vae_config = create_vae_diffusers_config_test() + elif args.config_type == "big": + vae_config = create_vae_diffusers_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return vae_config + + +def create_unidiffuser_unet_config(config_type, version): + # Hardcoded for now + if args.config_type == "test": + unet_config = create_unidiffuser_unet_config_test() + elif args.config_type == "big": + unet_config = create_unidiffuser_unet_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + # Unidiffuser-v1 uses data type embeddings + if version == 1: + unet_config["use_data_type_embedding"] = True + return unet_config + + +def create_text_decoder_config(config_type): + # Hardcoded for now + if args.config_type == "test": + text_decoder_config = create_text_decoder_config_test() + elif args.config_type == "big": + text_decoder_config = create_text_decoder_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return text_decoder_config + + +# Hardcoded configs for test versions of the UniDiffuser models, corresponding to those in the fast default tests. +def create_vae_diffusers_config_test(): + vae_config = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [32, 64], + "latent_channels": 4, + "layers_per_block": 1, + } + return vae_config + + +def create_unidiffuser_unet_config_test(): + unet_config = { + "text_dim": 32, + "clip_img_dim": 32, + "num_text_tokens": 77, + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 2, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 16, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +def create_text_decoder_config_test(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 32, + "prefix_hidden_dim": 32, + "vocab_size": 1025, # 1024 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 32, + "n_layer": 5, + "n_head": 4, + "n_inner": 37, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Hardcoded configs for the UniDiffuser V1 model at https://huggingface.co/thu-ml/unidiffuser-v1 +# See also https://github.com/thu-ml/unidiffuser/blob/main/configs/sample_unidiffuser_v1.py +def create_vae_diffusers_config_big(): + vae_config = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [128, 256, 512, 512], + "latent_channels": 4, + "layers_per_block": 2, + } + return vae_config + + +def create_unidiffuser_unet_config_big(): + unet_config = { + "text_dim": 64, + "clip_img_dim": 512, + "num_text_tokens": 77, + "num_attention_heads": 24, + "attention_head_dim": 64, + "in_channels": 4, + "out_channels": 4, + "num_layers": 30, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 64, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +# From https://huggingface.co/gpt2/blob/main/config.json, the GPT2 checkpoint used by UniDiffuser +def create_text_decoder_config_big(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 768, + "prefix_hidden_dim": 64, + "vocab_size": 50258, # 50257 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 768, + "n_layer": 12, + "n_head": 12, + "n_inner": 3072, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint +def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1): + """ + Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL. + """ + # autoencoder_kl.pth ckpt is a torch state dict + vae_state_dict = torch.load(ckpt, map_location="cpu") + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_checkpoint) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_uvit_block_to_diffusers_block( + uvit_state_dict, + new_state_dict, + block_prefix, + new_prefix="transformer.transformer_", + skip_connection=False, +): + """ + Maps the keys in a UniDiffuser transformer block (`Block`) to the keys in a diffusers transformer block + (`UTransformerBlock`/`UniDiffuserBlock`). + """ + prefix = new_prefix + block_prefix + if skip_connection: + new_state_dict[prefix + ".skip.skip_linear.weight"] = uvit_state_dict[block_prefix + ".skip_linear.weight"] + new_state_dict[prefix + ".skip.skip_linear.bias"] = uvit_state_dict[block_prefix + ".skip_linear.bias"] + new_state_dict[prefix + ".skip.norm.weight"] = uvit_state_dict[block_prefix + ".norm1.weight"] + new_state_dict[prefix + ".skip.norm.bias"] = uvit_state_dict[block_prefix + ".norm1.bias"] + + # Create the prefix string for out_blocks. + prefix += ".block" + + # Split up attention qkv.weight into to_q.weight, to_k.weight, to_v.weight + qkv = uvit_state_dict[block_prefix + ".attn.qkv.weight"] + new_attn_keys = [".attn1.to_q.weight", ".attn1.to_k.weight", ".attn1.to_v.weight"] + new_attn_keys = [prefix + key for key in new_attn_keys] + shape = qkv.shape[0] // len(new_attn_keys) + for i, attn_key in enumerate(new_attn_keys): + new_state_dict[attn_key] = qkv[i * shape : (i + 1) * shape] + + new_state_dict[prefix + ".attn1.to_out.0.weight"] = uvit_state_dict[block_prefix + ".attn.proj.weight"] + new_state_dict[prefix + ".attn1.to_out.0.bias"] = uvit_state_dict[block_prefix + ".attn.proj.bias"] + new_state_dict[prefix + ".norm1.weight"] = uvit_state_dict[block_prefix + ".norm2.weight"] + new_state_dict[prefix + ".norm1.bias"] = uvit_state_dict[block_prefix + ".norm2.bias"] + new_state_dict[prefix + ".ff.net.0.proj.weight"] = uvit_state_dict[block_prefix + ".mlp.fc1.weight"] + new_state_dict[prefix + ".ff.net.0.proj.bias"] = uvit_state_dict[block_prefix + ".mlp.fc1.bias"] + new_state_dict[prefix + ".ff.net.2.weight"] = uvit_state_dict[block_prefix + ".mlp.fc2.weight"] + new_state_dict[prefix + ".ff.net.2.bias"] = uvit_state_dict[block_prefix + ".mlp.fc2.bias"] + new_state_dict[prefix + ".norm3.weight"] = uvit_state_dict[block_prefix + ".norm3.weight"] + new_state_dict[prefix + ".norm3.bias"] = uvit_state_dict[block_prefix + ".norm3.bias"] + + return uvit_state_dict, new_state_dict + + +def convert_uvit_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser uvit_v*.pth checkpoint to a diffusers UniDiffusersModel. + """ + # uvit_v*.pth ckpt is a torch state dict + uvit_state_dict = torch.load(ckpt, map_location="cpu") + + new_state_dict = {} + + # Input layers + new_state_dict["vae_img_in.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["vae_img_in.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + new_state_dict["clip_img_in.weight"] = uvit_state_dict["clip_img_embed.weight"] + new_state_dict["clip_img_in.bias"] = uvit_state_dict["clip_img_embed.bias"] + new_state_dict["text_in.weight"] = uvit_state_dict["text_embed.weight"] + new_state_dict["text_in.bias"] = uvit_state_dict["text_embed.bias"] + + new_state_dict["pos_embed"] = uvit_state_dict["pos_embed"] + + # Handle data type token embeddings for UniDiffuser-v1 + if "token_embedding.weight" in uvit_state_dict and diffusers_model.use_data_type_embedding: + new_state_dict["data_type_pos_embed_token"] = uvit_state_dict["pos_embed_token"] + new_state_dict["data_type_token_embedding.weight"] = uvit_state_dict["token_embedding.weight"] + + # Also initialize the PatchEmbedding in UTransformer2DModel with the PatchEmbedding from the checkpoint. + # This isn't used in the current implementation, so might want to remove. + new_state_dict["transformer.pos_embed.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["transformer.pos_embed.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + + # Output layers + new_state_dict["transformer.norm_out.weight"] = uvit_state_dict["norm.weight"] + new_state_dict["transformer.norm_out.bias"] = uvit_state_dict["norm.bias"] + + new_state_dict["vae_img_out.weight"] = uvit_state_dict["decoder_pred.weight"] + new_state_dict["vae_img_out.bias"] = uvit_state_dict["decoder_pred.bias"] + new_state_dict["clip_img_out.weight"] = uvit_state_dict["clip_img_out.weight"] + new_state_dict["clip_img_out.bias"] = uvit_state_dict["clip_img_out.bias"] + new_state_dict["text_out.weight"] = uvit_state_dict["text_out.weight"] + new_state_dict["text_out.bias"] = uvit_state_dict["text_out.bias"] + + # in_blocks + in_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "in_blocks" in layer} + for in_block_prefix in list(in_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, in_block_prefix) + + # mid_block + # Assume there's only one mid block + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, "mid_block") + + # out_blocks + out_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "out_blocks" in layer} + for out_block_prefix in list(out_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, out_block_prefix, skip_connection=True) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_caption_decoder_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser caption_decoder.pth checkpoint to a diffusers UniDiffuserTextDecoder. + """ + # caption_decoder.pth ckpt is a torch state dict + checkpoint_state_dict = torch.load(ckpt, map_location="cpu") + decoder_state_dict = {} + # Remove the "module." prefix, if necessary + caption_decoder_key = "module." + for key in checkpoint_state_dict: + if key.startswith(caption_decoder_key): + decoder_state_dict[key.replace(caption_decoder_key, "")] = checkpoint_state_dict.get(key) + else: + decoder_state_dict[key] = checkpoint_state_dict.get(key) + + new_state_dict = {} + + # Encoder and Decoder + new_state_dict["encode_prefix.weight"] = decoder_state_dict["encode_prefix.weight"] + new_state_dict["encode_prefix.bias"] = decoder_state_dict["encode_prefix.bias"] + new_state_dict["decode_prefix.weight"] = decoder_state_dict["decode_prefix.weight"] + new_state_dict["decode_prefix.bias"] = decoder_state_dict["decode_prefix.bias"] + + # Internal GPT2LMHeadModel transformer model + for key, val in decoder_state_dict.items(): + if key.startswith("gpt"): + suffix = key[len("gpt") :] + new_state_dict["transformer" + suffix] = val + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--caption_decoder_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to caption decoder checkpoint to convert.", + ) + parser.add_argument( + "--uvit_checkpoint_path", default=None, type=str, required=False, help="Path to U-ViT checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to VAE checkpoint to convert.", + ) + parser.add_argument( + "--pipeline_output_path", + default=None, + type=str, + required=True, + help="Path to save the output pipeline to.", + ) + parser.add_argument( + "--config_type", + default="test", + type=str, + help=( + "Config type to use. Should be 'test' to create small models for testing or 'big' to convert a full" + " checkpoint." + ), + ) + parser.add_argument( + "--version", + default=0, + type=int, + help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.", + ) + + args = parser.parse_args() + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(args.config_type) + vae = AutoencoderKL(**vae_config) + vae = convert_vae_to_diffusers(args.vae_checkpoint_path, vae) + + # Convert the U-ViT ("unet") model. + if args.uvit_checkpoint_path is not None: + unet_config = create_unidiffuser_unet_config(args.config_type, args.version) + unet = UniDiffuserModel(**unet_config) + unet = convert_uvit_to_diffusers(args.uvit_checkpoint_path, unet) + + # Convert the caption decoder ("text_decoder") model. + if args.caption_decoder_checkpoint_path is not None: + text_decoder_config = create_text_decoder_config(args.config_type) + text_decoder = UniDiffuserTextDecoder(**text_decoder_config) + text_decoder = convert_caption_decoder_to_diffusers(args.caption_decoder_checkpoint_path, text_decoder) + + # Scheduler is the same for both the test and big models. + scheduler_config = SCHEDULER_CONFIG + scheduler = DPMSolverMultistepScheduler( + beta_start=scheduler_config.beta_start, + beta_end=scheduler_config.beta_end, + beta_schedule=scheduler_config.beta_schedule, + solver_order=scheduler_config.solver_order, + ) + + if args.config_type == "test": + # Make a small random CLIPTextModel + torch.manual_seed(0) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(clip_text_encoder_config) + clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # Make a small random CLIPVisionModel and accompanying CLIPImageProcessor + torch.manual_seed(0) + clip_image_encoder_config = CLIPVisionConfig( + image_size=32, + patch_size=2, + num_channels=3, + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + ) + image_encoder = CLIPVisionModelWithProjection(clip_image_encoder_config) + image_processor = CLIPImageProcessor(crop_size=32, size=32) + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + elif args.config_type == "big": + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") + image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32") + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + else: + raise NotImplementedError( + f"Config type {args.config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + + pipeline = UniDiffuserPipeline( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + pipeline.save_pretrained(args.pipeline_output_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f6d8c254d157..402f6eaa749a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -129,6 +129,7 @@ IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, + ImageTextPipelineOutput, KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline, @@ -161,6 +162,9 @@ TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index bb3fc5d04cb6..9e68538f233c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -89,6 +89,7 @@ from .stable_diffusion_safe import StableDiffusionPipelineSafe from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/unidiffuser/__init__.py b/src/diffusers/pipelines/unidiffuser/__init__.py new file mode 100644 index 000000000000..a774e3274030 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,20 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) +else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100644 index 000000000000..febc8e09e6ab --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,294 @@ +from typing import Optional + +import numpy as np +import torch +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers.modeling_utils import ModuleUtilsMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + """ + Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + prefix_length (`int`): + Max number of prefix tokens that will be supplied to the model. + prefix_inner_dim (`int`): + The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + CLIP text encoder. + prefix_hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + """ + + @register_to_config + def __init__( + self, + prefix_length: int, + prefix_inner_dim: int, + prefix_hidden_dim: Optional[int] = None, + vocab_size: int = 50257, # Start of GPT2 config args + n_positions: int = 1024, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + n_inner: Optional[int] = None, + activation_function: str = "gelu_new", + resid_pdrop: float = 0.1, + embd_pdrop: float = 0.1, + attn_pdrop: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + scale_attn_weights: bool = True, + use_cache: bool = True, + scale_attn_by_inverse_layer_idx: bool = False, + reorder_and_upcast_attn: bool = False, + ): + super().__init__() + + self.prefix_length = prefix_length + + if prefix_inner_dim != n_embd and prefix_hidden_dim is None: + raise ValueError( + f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and" + f" `n_embd`: {n_embd} are not equal." + ) + + self.prefix_inner_dim = prefix_inner_dim + self.prefix_hidden_dim = prefix_hidden_dim + + self.encode_prefix = ( + nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim) + if self.prefix_hidden_dim is not None + else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) + + gpt_config = GPT2Config( + vocab_size=vocab_size, + n_positions=n_positions, + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, + activation_function=activation_function, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attn_pdrop=attn_pdrop, + layer_norm_epsilon=layer_norm_epsilon, + initializer_range=initializer_range, + scale_attn_weights=scale_attn_weights, + use_cache=use_cache, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + self.transformer = GPT2LMHeadModel(gpt_config) + + def forward( + self, + input_ids: torch.Tensor, + prefix_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + Args: + input_ids (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to preprend to the embedded tokens. + attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + Labels to use for language modeling. + """ + embedding_text = self.transformer.transformer.wte(input_ids) + hidden = self.encode_prefix(prefix_embeds) + prefix_embeds = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device) + labels = torch.cat((dummy_token, input_ids), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask) + if self.prefix_hidden_dim is not None: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def encode(self, prefix): + return self.encode_prefix(prefix) + + @torch.no_grad() + def generate_captions(self, features, eos_token_id, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + eos_token_id (`int`): + The token ID of the EOS token for the text decoder model. + device: + Device to perform text generation on. + + Returns: + `List[str]`: A list of strings generated from the decoder model. + """ + + features = torch.split(features, 1, dim=0) + generated_tokens = [] + generated_seq_lengths = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + output_tokens, seq_lengths = self.generate_beam( + input_embeds=feature, device=device, eos_token_id=eos_token_id + ) + generated_tokens.append(output_tokens[0]) + generated_seq_lengths.append(seq_lengths[0]) + generated_tokens = torch.stack(generated_tokens) + generated_seq_lengths = torch.stack(generated_seq_lengths) + return generated_tokens, generated_seq_lengths + + @torch.no_grad() + def generate_beam( + self, + input_ids=None, + input_embeds=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + eos_token_id: Optional[int] = None, + ): + """ + Generates text using the given tokenizer and text prompt or token embedding via beam search. This + implementation is based on the beam search implementation from the [original UniDiffuser + code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89). + + Args: + eos_token_id (`int`, *optional*): + The token ID of the EOS token for the text decoder model. + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds` + must be supplied. + input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + An embedded representation to directly pass to the transformer as a prefix for beam search. One of + `input_ids` and `input_embeds` must be supplied. + device: + The device to perform beam search on. + beam_size (`int`, *optional*, defaults to `5`): + The number of best states to store during beam search. + entry_length (`int`, *optional*, defaults to `67`): + The number of iterations to run beam search. + temperature (`float`, *optional*, defaults to 1.0): + The temperature to use when performing the softmax over logits from the decoding model. + + Returns: + `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated + token sequences sorted by score in descending order, and the second element is the sequence lengths + corresponding to those sequences. + """ + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = eos_token_id + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + + if input_embeds is not None: + generated = input_embeds + else: + generated = self.transformer.transformer.wte(input_ids) + + for i in range(entry_length): + outputs = self.transformer(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + order = scores.argsort(descending=True) + # tokens tensors are already padded to max_seq_length + output_texts = [tokens[i] for i in order] + output_texts = torch.stack(output_texts, dim=0) + seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype) + return output_texts, seq_lengths diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100644 index 000000000000..b7829f76ec12 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,1196 @@ +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import AdaLayerNorm, FeedForward +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.transformer_2d import Transformer2DModelOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, + \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for + generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_pos_embed=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.use_pos_embed = use_pos_embed + if self.use_pos_embed: + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.use_pos_embed: + return latent + self.pos_embed + else: + return latent + + +class SkipBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. + `pre_layer_norm = True`. + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + else: + attn_output = self.norm1(attn_output) + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + # Post-LayerNorm + if not self.pre_layer_norm: + ff_output = self.norm3(ff_output) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +# Like UTransformerBlock except with LayerNorms on the residual backbone of the block +# Modified from diffusers.models.attention.BasicTransformerBlock +class UniDiffuserBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = False, + final_dropout: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + return hidden_states + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input and torch.nn.LayerNorm currently +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] + layer and then reshaped to (b, t, d). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = 2, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Input + # Only support patch input of shape (batch_size, num_channels, height, width) for now + assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." + + assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" + + # 2. Define input layers + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + # Quick hack to make the transformer block type configurable + if block_type == "unidiffuser": + block_cls = UniDiffuserBlock + else: + block_cls = UTransformerBlock + self.transformer_in_blocks = nn.ModuleList( + [ + block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + ), + "block": block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + + # Following the UniDiffuser U-ViT implementation, we process the transformer output with + # a LayerNorm layer with per-element affine params + self.norm_out = nn.LayerNorm(inner_dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + unpatchify: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + unpatchify (`bool`, *optional*, defaults to `True`): + Whether to unpatchify the transformer output. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. Check inputs + + if not unpatchify and return_dict: + raise ValueError( + f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" + f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" + " rather than (batch_size, num_channels, height, width)." + ) + + # 1. Input + if not hidden_states_is_embedding: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_out_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic + hidden_states = self.norm_out(hidden_states) + # hidden_states = self.proj_out(hidden_states) + + if unpatchify: + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + else: + output = hidden_states + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + ff_final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + use_data_type_embedding (`bool`, *optional*): + Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 + is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` + argument, which can either be `1` to use the weights trained on non-publically-available data or `0` + otherwise. This argument is subsequently embedded by the data type embedding, if used. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_text_tokens: int = 77, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + use_timestep_embedding=False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = True, + use_data_type_embedding: bool = False, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # 1.1 Input layers for text and image input + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # 1.2. Timestep embeddings for t_img, t_text + self.timestep_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_img_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + self.timestep_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_text_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + # 1.3. Positional embedding + self.num_text_tokens = num_text_tokens + self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) + self.pos_embed_drop = nn.Dropout(p=dropout) + trunc_normal_(self.pos_embed, std=0.02) + + # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary + self.use_data_type_embedding = use_data_type_embedding + if self.use_data_type_embedding: + self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) + self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + block_type=block_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + use_patch_pos_embed=use_patch_pos_embed, + ff_final_dropout=ff_final_dropout, + ) + + # 3. Define output layers + patch_dim = (patch_size**2) * out_channels + self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward( + self, + latent_image_embeds: torch.FloatTensor, + image_embeds: torch.FloatTensor, + prompt_embeds: torch.FloatTensor, + timestep_img: Union[torch.Tensor, float, int], + timestep_text: Union[torch.Tensor, float, int], + data_type: Optional[Union[torch.Tensor, float, int]] = 1, + encoder_hidden_states=None, + cross_attention_kwargs=None, + ): + """ + Args: + latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + timestep_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + timestep_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): + Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, + or `0` otherwise. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + + + Returns: + `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE + image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text + embedding. + """ + batch_size = latent_image_embeds.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(latent_image_embeds) + clip_hidden_states = self.clip_img_in(image_embeds) + text_hidden_states = self.text_in(prompt_embeds) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_img): + timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) + + timestep_img_token = self.timestep_img_proj(timestep_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_img_token = timestep_img_token.to(dtype=self.dtype) + timestep_img_token = self.timestep_img_embed(timestep_img_token) + timestep_img_token = timestep_img_token.unsqueeze(dim=1) + + # 1.3. Encode text timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_text): + timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) + + timestep_text_token = self.timestep_text_proj(timestep_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_text_token = timestep_text_token.to(dtype=self.dtype) + timestep_text_token = self.timestep_text_embed(timestep_text_token) + timestep_text_token = timestep_text_token.unsqueeze(dim=1) + + # 1.4. Concatenate all of the embeddings together. + if self.use_data_type_embedding: + assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" + if not torch.is_tensor(data_type): + data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) + + data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) + hidden_states = torch.cat( + [ + timestep_img_token, + timestep_text_token, + data_type_token, + text_hidden_states, + clip_hidden_states, + vae_hidden_states, + ], + dim=1, + ) + else: + hidden_states = torch.cat( + [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], + dim=1, + ) + + # 1.5. Prepare the positional embeddings and add to hidden states + # Note: I think img_vae should always have the proper shape, so there's no need to interpolate + # the position embeddings. + if self.use_data_type_embedding: + pos_embed = torch.cat( + [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 + ) + else: + pos_embed = self.pos_embed + hidden_states = hidden_states + pos_embed + hidden_states = self.pos_embed_drop(hidden_states) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=None, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + hidden_states_is_embedding=True, + unpatchify=False, + )[0] + + # 3. Output + # Split out the predicted noise representation. + if self.use_data_type_embedding: + ( + t_img_token_out, + t_text_token_out, + data_type_token_out, + text_out, + img_clip_out, + img_vae_out, + ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) + else: + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100644 index 000000000000..36e5411b4215 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1422 @@ +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ...utils.outputs import BaseOutput +from ..pipeline_utils import DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + text (`List[str]` or `List[List[str]]`) + List of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. Text generated by the diffusion pipeline. + """ + + images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + text: Optional[Union[List[str], List[List[str]]]] + + +class UniDiffuserPipeline(DiffusionPipeline): + r""" + Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports + unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and + joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation, along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text + prompts. + image_encoder ([`CLIPVisionModel`]): + UniDiffuser uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode + images as part of its image representation, along with the VAE latent representation. + image_processor ([`CLIPImageProcessor`]): + CLIP image processor of class + [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), + used to preprocess the image before CLIP encoding it with `image_encoder`. + clip_tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which + is used to tokenizer a prompt before encoding it with `text_encoder`. + text_decoder ([`UniDiffuserTextDecoder`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + text_tokenizer ([`GPT2Tokenizer`]): + Tokenizer of class + [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which + is used along with the `text_decoder` to decode text for text generation. + unet ([`UniDiffuserModel`]): + UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a + [`Transformer2DModel`] with U-Net-style skip connections between transformer layers. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The + original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPImageProcessor, + clip_tokenizer: CLIPTokenizer, + text_decoder: UniDiffuserTextDecoder, + text_tokenizer: GPT2Tokenizer, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: + raise ValueError( + f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" + f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.num_channels_latents = vae.config.latent_channels + self.text_encoder_seq_len = text_encoder.config.max_position_embeddings + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_projection_dim = image_encoder.config.projection_dim + self.unet_resolution = unet.config.sample_size + + self.text_intermediate_dim = self.text_encoder_hidden_size + if self.text_decoder.prefix_hidden_dim is not None: + self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim + + self.mode = None + + # TODO: handle safety checking? + self.safety_checker = None + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.image_encoder, self.text_decoder]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): + r""" + Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set + mode will be used. + """ + prompt_available = (prompt is not None) or (prompt_embeds is not None) + image_available = image is not None + input_available = prompt_available or image_available + + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + full_latents_available = latents is not None + image_latents_available = vae_latents_available and clip_latents_available + all_indv_latents_available = prompt_latents_available and image_latents_available + + if self.mode is not None: + # Preferentially use the mode set by the user + mode = self.mode + elif prompt_available: + mode = "text2img" + elif image_available: + mode = "img2text" + else: + # Neither prompt nor image supplied, infer based on availability of latents + if full_latents_available or all_indv_latents_available: + mode = "joint" + elif prompt_latents_available: + mode = "text" + elif image_latents_available: + mode = "img" + else: + # No inputs or latents available + mode = "joint" + + # Give warnings for ambiguous cases + if self.mode is None and prompt_available and image_available: + logger.warning( + f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," + f" defaulting to mode '{mode}'." + ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Functions to manually set the mode + def set_text_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") text generation.""" + self.mode = "text" + + def set_image_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") image generation.""" + self.mode = "img" + + def set_text_to_image_mode(self): + r"""Manually set the generation mode to text-conditioned image generation.""" + self.mode = "text2img" + + def set_image_to_text_mode(self): + r"""Manually set the generation mode to image-conditioned text generation.""" + self.mode = "img2text" + + def set_joint_mode(self): + r"""Manually set the generation mode to unconditional joint image-text generation.""" + self.mode = "joint" + + def reset_mode(self): + r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" + self.mode = None + + def _infer_batch_size( + self, + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ): + r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" + if num_images_per_prompt is None: + num_images_per_prompt = 1 + if num_prompts_per_image is None: + num_prompts_per_image = 1 + + assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" + assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" + + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + multiplier = num_images_per_prompt + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.FloatTensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + multiplier = num_prompts_per_image + elif mode in ["img"]: + if vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + multiplier = num_images_per_prompt + elif mode in ["text"]: + if prompt_latents is not None: + batch_size = prompt_latents.shape[0] + else: + batch_size = 1 + multiplier = num_prompts_per_image + elif mode in ["joint"]: + if latents is not None: + batch_size = latents.shape[0] + elif prompt_latents is not None: + batch_size = prompt_latents.shape[0] + elif vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + + if num_images_per_prompt == num_prompts_per_image: + multiplier = num_images_per_prompt + else: + multiplier = min(num_images_per_prompt, num_prompts_per_image) + logger.warning( + f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" + f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" + f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." + ) + return batch_size, multiplier + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + # self.tokenizer => self.clip_tokenizer + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.clip_tokenizer.batch_decode( + untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.clip_tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Add num_prompts_per_image argument, sample from autoencoder moment distribution + def encode_image_vae_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + do_classifier_free_guidance, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + * self.vae.config.scaling_factor + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # Scale image_latents by the VAE's scaling factor + image_latents = image_latents * self.vae.config.scaling_factor + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + preprocessed_image = self.image_processor.preprocess( + image, + return_tensors="pt", + ) + preprocessed_image = preprocessed_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list): + image_latents = [ + self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.image_encoder(**preprocessed_image).image_embeds + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + # Note that the CLIP latents are not decoded for image generation. + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + # Rename: decode_latents -> decode_image_latents + def decode_image_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_text_latents( + self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size * num_images_per_prompt, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shace (B, L, D) + latents = latents.repeat(num_images_per_prompt, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument. + def prepare_image_vae_latents( + self, + batch_size, + num_prompts_per_image, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size * num_prompts_per_image, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, C, H, W) + latents = latents.repeat(num_prompts_per_image, 1, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents( + self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded image. + shape = (batch_size * num_prompts_per_image, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, L, D) + latents = latents.repeat(num_prompts_per_image, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_intermediate_dim + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + text = torch.reshape(text, (text.shape[0], -1)) + return torch.concat([img_vae, img_clip, text], dim=-1) + + def _get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ): + r""" + Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. + """ + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + _, _, text_out_uncond = self.unet( + img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type + ) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, + img_clip_latents, + prompt_embeds, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + def check_latents_shape(self, latents_name, latents, expected_shape): + latents_shape = latents.shape + expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension + expected_shape_str = ", ".join(str(dim) for dim in expected_shape) + if len(latents_shape) != expected_num_dims: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {len(latents_shape)} dimensions." + ) + for i in range(1, expected_num_dims): + if latents_shape[i] != expected_shape[i - 1]: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." + ) + + def check_inputs( + self, + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + latents=None, + prompt_latents=None, + vae_latents=None, + clip_latents=None, + ): + # Check inputs before running the generative process. + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if mode == "img2text": + if image is None: + raise ValueError("`img2text` mode requires an image to be provided.") + + # Check provided latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + full_latents_available = latents is not None + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + + if full_latents_available: + individual_latents_available = ( + prompt_latents is not None or vae_latents is not None or clip_latents is not None + ) + if individual_latents_available: + logger.warning( + "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" + " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." + ) + # Check shape of full latents + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim + latents_expected_shape = (latents_dim,) + self.check_latents_shape("latents", latents, latents_expected_shape) + + # Check individual latent shapes, if present + if prompt_latents_available: + prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) + self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) + + if vae_latents_available: + vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) + self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) + + if clip_latents_available: + clip_latents_expected_shape = (1, self.image_encoder_projection_dim) + self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) + + if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: + if vae_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" + f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." + ) + + if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: + if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" + f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" + f" != {clip_latents.shape[0]}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + data_type: Optional[int] = 1, + num_inference_steps: int = 50, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + num_prompts_per_image: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_latents: Optional[torch.FloatTensor] = None, + vae_latents: Optional[torch.FloatTensor] = None, + clip_latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. Required for text-conditioned image generation (`text2img`) mode. + image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch. Required for image-conditioned text generation + (`img2text`) mode. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + data_type (`int`, *optional*, defaults to 1): + The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type + embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the original [UniDiffuser + paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, + which satisfies `w = w' + 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). Used in text-conditioned image generation (`text2img`) mode. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and + `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and + `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint + image-text generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note + that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override + the value of `prompt_latents`, `vae_latents`, and `clip_latents`. + prompt_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + vae_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + clip_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned + image generation (`text2img`) mode. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. Used in text-conditioned image generation (`text2img`) mode. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet_resolution * self.vae_scale_factor + width = width or self.unet_resolution * self.vae_scale_factor + + # 1. Check inputs + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) + self.check_inputs( + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + + # 2. Define call parameters + batch_size, multiplier = self._infer_batch_size( + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + device = self._execution_device + reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this differs from the formulation in the unidiffusers paper! + # do_classifier_free_guidance = guidance_scale > 1.0 + + # check if scheduler is in sigmas space + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + if latents is not None: + # Overwrite individual latents + vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=multiplier, + do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + else: + # 3.2. Prepare text latent variables, if input not available + prompt_embeds = self.prepare_text_latents( + batch_size=batch_size, + num_images_per_prompt=multiplier, + seq_len=self.text_encoder_seq_len, + hidden_size=self.text_encoder_hidden_size, + dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision + device=device, + generator=generator, + latents=prompt_latents, + ) + + if reduce_text_emb_dim: + prompt_embeds = self.text_decoder.encode(prompt_embeds) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None, "`img2text` requires a conditioning image" + # Encode image using VAE + image_vae = preprocess(image) + height, width = image_vae.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image=image_vae, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG + generator=generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image=image, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) + image_clip_latents = image_clip_latents.unsqueeze(1) + else: + # 4.2. Prepare image latent variables, if input not available + # Prepare image VAE latents in latent space + image_vae_latents = self.prepare_image_vae_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + clip_img_dim=self.image_encoder_projection_dim, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + # max_timestep = timesteps[0] + max_timestep = self.scheduler.config.num_train_timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self._get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + gen_image = None + gen_text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + + # Map latent VAE image back to pixel space + gen_image = self.decode_image_latents(image_vae_latents) + + # Generate text using the text decoder + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + gen_image = self.decode_image_latents(image_vae_latents) + elif mode in ["img2text", "text"]: + text_latents = latents + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + + # 10. Convert to PIL + if output_type == "pil" and gen_image is not None: + gen_image = self.numpy_to_pil(gen_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (gen_image, gen_text) + + return ImageTextPipelineOutput(images=gen_image, text=gen_text) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ea6a61cf7587..95d07c081ccd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ImageTextPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class KandinskyImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -632,6 +647,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UniDiffuserModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserTextDecoder(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/unidiffuser/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py new file mode 100644 index 000000000000..f9f798ebe55d --- /dev/null +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -0,0 +1,670 @@ +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DPMSolverMultistepScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) +from diffusers.utils import floats_tensor, load_image, randn_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = UniDiffuserPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + + def get_dummy_components(self): + unet = UniDiffuserModel.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="unet", + ) + + scheduler = DPMSolverMultistepScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + solver_order=3, + ) + + vae = AutoencoderKL.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="vae", + ) + + text_encoder = CLIPTextModel.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_encoder", + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="clip_tokenizer", + ) + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="image_encoder", + ) + # From the Stable Diffusion Image Variation pipeline tests + image_processor = CLIPImageProcessor(crop_size=32, size=32) + # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_tokenizer = GPT2Tokenizer.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_tokenizer", + ) + text_decoder = UniDiffuserTextDecoder.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_decoder", + ) + + components = { + "vae": vae, + "text_encoder": text_encoder, + "image_encoder": image_encoder, + "image_processor": image_processor, + "clip_tokenizer": clip_tokenizer, + "text_decoder": text_decoder, + "text_tokenizer": text_tokenizer, + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def get_fixed_latents(self, device, seed=0): + if type(device) == str: + device = torch.device(device) + generator = torch.Generator(device=device).manual_seed(seed) + # Hardcode the shapes for now. + prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32) + vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32) + clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32) + + latents = { + "prompt_latents": prompt_latents, + "vae_latents": vae_latents, + "clip_latents": clip_latents, + } + return latents + + def get_dummy_inputs_with_latents(self, device, seed=0): + # image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + # image = image.cpu().permute(0, 2, 3, 1)[0] + # image = Image.fromarray(np.uint8(image)).convert("RGB") + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg", + ) + image = image.resize((32, 32)) + latents = self.get_fixed_latents(device, seed=seed) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "prompt_latents": latents.get("prompt_latents"), + "vae_latents": latents.get("vae_latents"), + "clip_latents": latents.get("clip_latents"), + } + return inputs + + def test_unidiffuser_default_joint_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_joint_no_cfg_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + # Set guidance scale to 1.0 to turn off CFG + inputs["guidance_scale"] = 1.0 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_text2img_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_image_0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img' + unidiffuser_pipe.set_image_mode() + assert unidiffuser_pipe.mode == "img" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. + del inputs["prompt"] + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_text_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img' + unidiffuser_pipe.set_text_mode() + assert unidiffuser_pipe.mode == "text" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. + del inputs["prompt"] + del inputs["image"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_img2text_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_joint_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_img2text_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_text2img_multiple_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs(device) + # Delete image for text-conditioned image generation + del inputs["image"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + image = unidiffuser_pipe(**inputs).images + assert image.shape == (2, 32, 32, 3) + + def test_unidiffuser_img2text_multiple_prompts(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + text = unidiffuser_pipe(**inputs).text + + assert len(text) == 3 + + def test_unidiffuser_text2img_multiple_images_with_latents(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + image = unidiffuser_pipe(**inputs).images + assert image.shape == (2, 32, 32, 3) + + def test_unidiffuser_img2text_multiple_prompts_with_latents(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + text = unidiffuser_pipe(**inputs).text + + assert len(text) == 3 + + @require_torch_gpu + def test_unidiffuser_default_joint_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = '" This This' + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + @require_torch_gpu + def test_unidiffuser_default_text2img_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + @require_torch_gpu + def test_unidiffuser_default_img2text_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + inputs["data_type"] = 1 + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = '" This This' + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + +@slow +@require_torch_gpu +class UniDiffuserPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0, generate_latents=False): + generator = torch.manual_seed(seed) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" + ) + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": 8.0, + "output_type": "numpy", + } + if generate_latents: + latents = self.get_fixed_latents(device, seed=seed) + for latent_name, latent_tensor in latents.items(): + inputs[latent_name] = latent_tensor + return inputs + + def get_fixed_latents(self, device, seed=0): + if type(device) == str: + device = torch.device(device) + latent_device = torch.device("cpu") + generator = torch.Generator(device=latent_device).manual_seed(seed) + # Hardcode the shapes for now. + prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32) + vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32) + clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32) + + # Move latents onto desired device. + prompt_latents = prompt_latents.to(device) + vae_latents = vae_latents.to(device) + clip_latents = clip_latents.to(device) + + latents = { + "prompt_latents": prompt_latents, + "vae_latents": vae_latents, + "clip_latents": clip_latents, + } + return latents + + def test_unidiffuser_default_joint_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_inputs(device=torch_device, generate_latents=True) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1 + + expected_text_prefix = "A living room" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1 + + def test_unidiffuser_default_img2text_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["prompt"] + sample = pipe(**inputs) + text = sample.text + + expected_text_prefix = "An astronaut" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_joint_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_inputs(device=torch_device, generate_latents=True) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1 + + expected_text_prefix = "A living room" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1 + + def test_unidiffuser_default_img2text_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["prompt"] + sample = pipe(**inputs) + text = sample.text + + expected_text_prefix = "An astronaut" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix