diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 72808df049c9..db9e72a4ea20 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -184,6 +184,8 @@ title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/consistency_models + title: Consistency Models - local: api/pipelines/controlnet title: ControlNet - local: api/pipelines/cycle_diffusion @@ -274,6 +276,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/cm_stochastic_iterative + title: Consistency Model Multistep Scheduler - local: api/schedulers/ddim title: DDIM - local: api/schedulers/ddim_inverse diff --git a/docs/source/en/api/pipelines/consistency_models.mdx b/docs/source/en/api/pipelines/consistency_models.mdx new file mode 100644 index 000000000000..715743b87a12 --- /dev/null +++ b/docs/source/en/api/pipelines/consistency_models.mdx @@ -0,0 +1,87 @@ +# Consistency Models + +Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. + +The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows: + +*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. 
Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.* + +Resources: + +* [Paper](https://arxiv.org/abs/2303.01469) +* [Original Code](https://github.com/openai/consistency_models) + +Available Checkpoints are: +- *cd_imagenet64_l2 (64x64 resolution)* [openai/consistency-model-pipelines](https://huggingface.co/openai/consistency-model-pipelines) +- *cd_imagenet64_lpips (64x64 resolution)* [openai/diffusers-cd_imagenet64_lpips](https://huggingface.co/openai/diffusers-cd_imagenet64_lpips) +- *ct_imagenet64 (64x64 resolution)* [openai/diffusers-ct_imagenet64](https://huggingface.co/openai/diffusers-ct_imagenet64) +- *cd_bedroom256_l2 (256x256 resolution)* [openai/diffusers-cd_bedroom256_l2](https://huggingface.co/openai/diffusers-cd_bedroom256_l2) +- *cd_bedroom256_lpips (256x256 resolution)* [openai/diffusers-cd_bedroom256_lpips](https://huggingface.co/openai/diffusers-cd_bedroom256_lpips) +- *ct_bedroom256 (256x256 resolution)* [openai/diffusers-ct_bedroom256](https://huggingface.co/openai/diffusers-ct_bedroom256) +- *cd_cat256_l2 (256x256 resolution)* [openai/diffusers-cd_cat256_l2](https://huggingface.co/openai/diffusers-cd_cat256_l2) +- *cd_cat256_lpips (256x256 resolution)* [openai/diffusers-cd_cat256_lpips](https://huggingface.co/openai/diffusers-cd_cat256_lpips) +- *ct_cat256 (256x256 resolution)* [openai/diffusers-ct_cat256](https://huggingface.co/openai/diffusers-ct_cat256) + +## Available Pipelines + +| Pipeline | Tasks | Demo | Colab | +|:---:|:---:|:---:|:---:| +| 
[ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py) | *Unconditional Image Generation* | | | + +This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) :heart: + +## Usage Example + +```python +import torch + +from diffusers import ConsistencyModelPipeline + +device = "cuda" +# Load the cd_imagenet64_l2 checkpoint. +model_id_or_path = "openai/diffusers-cd_imagenet64_l2" +pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Onestep Sampling +image = pipe(num_inference_steps=1).images[0] +image.save("consistency_model_onestep_sample.png") + +# Onestep sampling, class-conditional image generation +# ImageNet-64 class label 145 corresponds to king penguins +image = pipe(num_inference_steps=1, class_labels=145).images[0] +image.save("consistency_model_onestep_sample_penguin.png") + +# Multistep sampling, class-conditional image generation +# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo. +# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 +image = pipe(timesteps=[22, 0], class_labels=145).images[0] +image.save("consistency_model_multistep_sample_penguin.png") +``` + +For an additional speed-up, one can also make use of `torch.compile`. Multiple images can be generated in <1 second as follows: + +```py +import torch +from diffusers import ConsistencyModelPipeline + +device = "cuda" +# Load the cd_bedroom256_lpips checkpoint. 
+model_id_or_path = "openai/diffusers-cd_bedroom256_lpips" +pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +# Multistep sampling +# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: +# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83 +for _ in range(10): + image = pipe(timesteps=[17, 0]).images[0] + image.show() +``` + +## ConsistencyModelPipeline +[[autodoc]] ConsistencyModelPipeline + - all + - __call__ diff --git a/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx new file mode 100644 index 000000000000..0cc40bde47a0 --- /dev/null +++ b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx @@ -0,0 +1,11 @@ +# Consistency Model Multistep Scheduler + +## Overview + +Multistep and onestep scheduler (Algorithm 1) introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. +Based on the [original consistency models implementation](https://github.com/openai/consistency_models). +Should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps. 
+ +## CMStochasticIterativeScheduler +[[autodoc]] CMStochasticIterativeScheduler + diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py new file mode 100644 index 000000000000..5a6158bb9867 --- /dev/null +++ b/scripts/convert_consistency_to_diffusers.py @@ -0,0 +1,313 @@ +import argparse +import os + +import torch + +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) + + +TEST_UNET_CONFIG = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": 1000, + "block_out_channels": [32, 64], + "attention_head_dim": 8, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +IMAGENET_64_UNET_CONFIG = { + "sample_size": 64, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 3, + "num_class_embeds": 1000, + "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +LSUN_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": None, + "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + 
"AttnUpBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "default", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +CD_SCHEDULER_CONFIG = { + "num_train_timesteps": 40, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_IMAGENET_64_SCHEDULER_CONFIG = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_LSUN_256_SCHEDULER_CONFIG = { + "num_train_timesteps": 151, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") + + +def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): + new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"] + new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"] + new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"] + new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"] + new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"] + new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"] + + if has_skip: + 
new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] + new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] + + return new_checkpoint + + +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) + + new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] + new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] + + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( + checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + ) + new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) + + return new_checkpoint + + +def con_pt_to_diffuser(checkpoint_path: str, unet_config): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + if unet_config["num_class_embeds"] is not None: + 
new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + down_block_types = unet_config["down_block_types"] + layers_per_block = unet_config["layers_per_block"] + attention_head_dim = unet_config["attention_head_dim"] + channels_list = unet_config["block_out_channels"] + current_layer = 1 + prev_channels = channels_list[0] + + for i, layer_type in enumerate(down_block_types): + current_channels = channels_list[i] + downsample_block_has_skip = current_channels != prev_channels + if layer_type == "ResnetDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + current_layer += 1 + + elif layer_type == "AttnDownBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + new_prefix = f"down_blocks.{i}.attentions.{j}" + old_prefix = f"input_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(down_block_types) - 1: + new_prefix = f"down_blocks.{i}.downsamplers.0" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + prev_channels = current_channels + + # hardcoded the mid-block for now + new_prefix = "mid_block.resnets.0" + old_prefix = "middle_block.0" + 
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = "mid_block.attentions.0" + old_prefix = "middle_block.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) + new_prefix = "mid_block.resnets.1" + old_prefix = "middle_block.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + current_layer = 0 + up_block_types = unet_config["up_block_types"] + + for i, layer_type in enumerate(up_block_types): + if layer_type == "ResnetUpsampleBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.1" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + elif layer_type == "AttnUpBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + new_prefix = f"up_blocks.{i}.attentions.{j}" + old_prefix = f"output_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + 
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." + ) + parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.") + + args = parser.parse_args() + args.class_cond = str2bool(args.class_cond) + + ckpt_name = os.path.basename(args.unet_path) + print(f"Checkpoint: {ckpt_name}") + + # Get U-Net config + if "imagenet64" in ckpt_name: + unet_config = IMAGENET_64_UNET_CONFIG + elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + unet_config = LSUN_256_UNET_CONFIG + elif "test" in ckpt_name: + unet_config = TEST_UNET_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + if not args.class_cond: + unet_config["num_class_embeds"] = None + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) + + image_unet = UNet2DModel(**unet_config) + image_unet.load_state_dict(converted_unet_ckpt) + + # Get scheduler config + if "cd" in ckpt_name or "test" in ckpt_name: + scheduler_config = CD_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "imagenet64" in ckpt_name: + scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config) + + consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler) + consistency_model.save_pretrained(args.dump_path) diff --git a/src/diffusers/__init__.py 
b/src/diffusers/__init__.py index 764f9204dffb..f0c25edd3fdc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -58,6 +58,7 @@ ) from .pipelines import ( AudioPipelineOutput, + ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, @@ -72,6 +73,7 @@ ScoreSdeVePipeline, ) from .schedulers import ( + CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 7077aa889190..3b17acd3d829 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -66,6 +66,10 @@ class UNet2DModel(ModelMixin, ConfigMixin): layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. 
@@ -96,6 +100,8 @@ def __init__( layers_per_block: int = 2, mid_block_scale_factor: float = 1, downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", act_fn: str = "silu", attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, @@ -168,6 +174,7 @@ def __init__( attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, ) self.down_blocks.append(down_block) @@ -207,6 +214,7 @@ def __init__( resnet_groups=norm_num_groups, attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index eee7e6023e88..d4e7bd4e03f7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -51,6 +51,7 @@ def get_down_block( resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, + downsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -88,18 +89,22 @@ def get_down_block( output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + 
downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -239,6 +244,7 @@ def get_up_block( resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, + upsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -319,18 +325,23 @@ def get_up_block( cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, - add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -747,11 +758,12 @@ def __init__( attention_head_dim=1, output_scale_factor=1.0, downsample_padding=1, - add_downsample=True, + downsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.downsample_type = downsample_type if attention_head_dim is None: logger.warn( @@ -793,7 +805,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_downsample: + if downsample_type == "conv": self.downsamplers = nn.ModuleList( [ Downsample2D( @@ -801,6 +813,24 @@ def __init__( ) ] ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + 
pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) else: self.downsamplers = None @@ -810,11 +840,14 @@ def forward(self, hidden_states, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) output_states += (hidden_states,) @@ -1860,12 +1893,14 @@ def __init__( resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, - add_upsample=True, + upsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.upsample_type = upsample_type + if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
@@ -1908,8 +1943,26 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_upsample: + if upsample_type == "conv": self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) else: self.upsamplers = None @@ -1925,7 +1978,10 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ca57756c6aa4..3926b3413e01 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -16,6 +16,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..fd78ddb3aae2 --- /dev/null +++ b/src/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1 @@ +from .pipeline_consistency_models import ConsistencyModelPipeline diff --git 
a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100644 index 000000000000..4e72e3fdbafe --- /dev/null +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,337 @@ +from typing import Callable, List, Optional, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import CMStochasticIterativeScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import ConsistencyModelPipeline + + >>> device = "cuda" + >>> # Load the cd_imagenet64_l2 checkpoint. + >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" + >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe.to(device) + + >>> # Onestep Sampling + >>> image = pipe(num_inference_steps=1).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample.png") + + >>> # Onestep sampling, class-conditional image generation + >>> # ImageNet-64 class label 145 corresponds to king penguins + >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") + + >>> # Multistep sampling, class-conditional image generation + >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: + >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 + >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") + ``` +""" + + +class 
ConsistencyModelPipeline(DiffusionPipeline): + r""" + Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 + + Args: + unet ([`UNet2DModel`]): + Unconditional or class-conditional U-Net architecture to denoise image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible + with [`CMStochasticIterativeScheduler`]. + """ + + def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + self.safety_checker = None + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. 
+ self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. 
Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + + # Output_type must be 'pil' + sample = self.numpy_to_pil(sample) + return sample + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) + else: + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." 
+ ) + + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + batch_size: int = 1, + class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, + num_inference_steps: int = 1, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Will not be used if the + model is not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 1): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+        # 0. Prepare call parameters
+        img_size = self.unet.config.sample_size
+        device = self._execution_device
+
+        # 1. Check inputs
+        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
+
+        # 2. Prepare image latents
+        # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
+        sample = self.prepare_latents(
+            batch_size=batch_size,
+            num_channels=self.unet.config.in_channels,
+            height=img_size,
+            width=img_size,
+            dtype=self.unet.dtype,
+            device=device,
+            generator=generator,
+            latents=latents,
+        )
+
+        # 3.
Handle class_labels for class-conditional models + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + scaled_sample = self.scheduler.scale_model_input(sample, t) + model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] + + sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + # 6. 
Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 935759bbb6af..0a07ce4baed2 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -28,6 +28,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py new file mode 100644 index 000000000000..fb296054d65b --- /dev/null +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -0,0 +1,380 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CMStochasticIterativeSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): + """ + Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the + paper [1]. + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based + Generative Models." https://arxiv.org/abs/2206.00364 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + sigma_min (`float`): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation. + sigma_max (`float`): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation. 
+ sigma_data (`float`): + The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the + original implementation, which is also the original value suggested in the EDM paper. + s_noise (`float`): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. This was set to 1.0 in the original implementation. + rho (`float`): + The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was + set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper. + clip_denoised (`bool`): + Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`. + timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): + Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing + order. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 40, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + s_noise: float = 1.0, + rho: float = 7.0, + clip_denoised: bool = True, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + ramp = np.linspace(0, 1, num_train_timesteps) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + # setable values + self.num_inference_steps = None + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps) + self.custom_timesteps = False + self.is_scale_input_called = False + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + return indices.item() + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the 
consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + Returns: + `torch.FloatTensor`: scaled input sample + """ + # Get sigma corresponding to timestep + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_idx = self.index_for_timestep(timestep) + sigma = self.sigmas[step_idx] + + sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + + self.is_scale_input_called = True + return sample + + def sigma_to_t(self, sigmas: Union[float, np.ndarray]): + """ + Gets scaled timesteps from the Karras sigmas, for input to the consistency model. + + Args: + sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas + Returns: + `float` or `np.ndarray`: scaled input timestep or scaled input timestep array + """ + if not isinstance(sigmas, np.ndarray): + sigmas = np.array(sigmas, dtype=np.float64) + + timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44) + + return timesteps + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. 
+ """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + + # Follow DDPMScheduler custom timesteps logic + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." 
+ ) + + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.custom_timesteps = False + + # Map timesteps to Karras sigmas directly for multistep sampling + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 + num_train_timesteps = self.config.num_train_timesteps + ramp = timesteps[::-1].copy() + ramp = ramp / (num_train_timesteps - 1) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + # Modified _convert_to_karras implementation that takes in ramp as argument + def _convert_to_karras(self, ramp): + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = self.config.sigma_min + sigma_max: float = self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def get_scalings_for_boundary_condition(self, sigma): + """ + Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper. + This enforces the consistency model boundary condition. + + Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min. 
+
+        Args:
+            sigma (`torch.FloatTensor`):
+                The current sigma in the Karras sigma schedule.
+        Returns:
+            `tuple`:
+                A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
+                (which weights the consistency model output) is the second element.
+        """
+        sigma_min = self.config.sigma_min
+        sigma_data = self.config.sigma_data
+
+        c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
+        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        return c_skip, c_out
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`float`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            generator (`torch.Generator`, *optional*): Random number generator.
+            return_dict (`bool`): option for returning tuple rather than CMStochasticIterativeSchedulerOutput class
+        Returns:
+            [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+
+        if (
+            isinstance(timestep, int)
+            or isinstance(timestep, torch.IntTensor)
+            or isinstance(timestep, torch.LongTensor)
+        ):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    f" `{self.__class__}.step()` is not supported.
Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max + + step_index = self.index_for_timestep(timestep) + + # sigma_next corresponds to next_t in original implementation + sigma = self.sigmas[step_index] + if step_index + 1 < self.config.num_train_timesteps: + sigma_next = self.sigmas[step_index + 1] + else: + # Set sigma_next to sigma_min + sigma_next = self.sigmas[-1] + + # Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) + + # 1. Denoise model output using boundary conditions + denoised = c_out * model_output + c_skip * sample + if self.config.clip_denoised: + denoised = denoised.clamp(-1, 1) + + # 2. Sample z ~ N(0, s_noise^2 * I) + # Noise is not used for onestep sampling. + if len(self.timesteps) > 1: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + noise = torch.zeros_like(model_output) + z = noise * self.config.s_noise + + sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) + + # 3. 
Return noisy sample + # tau = sigma_hat, eps = sigma_min + prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + + if not return_dict: + return (prev_sample,) + + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7a13bc89e883..20dbf84681d3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -210,6 +210,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) 
+ + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -390,6 +405,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CMStochasticIterativeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/consistency_models/__init__.py b/tests/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py new file mode 100644 index 000000000000..8dce90318505 --- /dev/null +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -0,0 +1,288 @@ +import gc +import unittest + +import numpy as np +import torch +from torch.backends.cuda import sdp_kernel + +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) +from diffusers.utils import randn_tensor, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_2, require_torch_gpu + +from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ConsistencyModelPipeline + params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS + batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS + + # 
Override required_optional_params to remove num_images_per_prompt + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "output_type", + "return_dict", + "callback", + "callback_steps", + ] + ) + + @property + def dummy_uncond_unet(self): + unet = UNet2DModel.from_pretrained( + "diffusers/consistency-models-test", + subfolder="test_unet", + ) + return unet + + @property + def dummy_cond_unet(self): + unet = UNet2DModel.from_pretrained( + "diffusers/consistency-models-test", + subfolder="test_unet_class_cond", + ) + return unet + + def get_dummy_components(self, class_cond=False): + if class_cond: + unet = self.dummy_cond_unet + else: + unet = self.dummy_uncond_unet + + # Default to CM multistep sampler + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "batch_size": 1, + "num_inference_steps": None, + "timesteps": [22, 0], + "generator": generator, + "output_type": "np", + } + + return inputs + + def test_consistency_model_pipeline_multistep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def 
test_consistency_model_pipeline_multistep_class_cond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["class_labels"] = 0 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep_class_cond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + inputs["class_labels"] = 0 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 
0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + +@slow +@require_torch_gpu +class ConsistencyModelPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): + generator = torch.manual_seed(seed) + + inputs = { + "num_inference_steps": None, + "timesteps": [22, 0], + "class_labels": 0, + "generator": generator, + "output_type": "np", + } + + if get_fixed_latents: + latents = self.get_fixed_latents(seed=seed, device=device, dtype=dtype, shape=shape) + inputs["latents"] = latents + + return inputs + + def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): + if type(device) == str: + device = torch.device(device) + generator = torch.Generator(device=device).manual_seed(seed) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def test_consistency_model_cd_multistep(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + def test_consistency_model_cd_onestep(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( 
+ num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + @require_torch_2 + def test_consistency_model_cd_multistep_flash_attn(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device, torch_dtype=torch.float16) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @require_torch_2 + def test_consistency_model_cd_onestep_flash_attn(self): + unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device, 
torch_dtype=torch.float16) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + + expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/schedulers/test_scheduler_consistency_model.py b/tests/schedulers/test_scheduler_consistency_model.py new file mode 100644 index 000000000000..66f07d024783 --- /dev/null +++ b/tests/schedulers/test_scheduler_consistency_model.py @@ -0,0 +1,150 @@ +import torch + +from diffusers import CMStochasticIterativeScheduler + +from .test_schedulers import SchedulerCommonTest + + +class CMStochasticIterativeSchedulerTest(SchedulerCommonTest): + scheduler_classes = (CMStochasticIterativeScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, + } + + config.update(**kwargs) + return config + + # Override test_step_shape to add CMStochasticIterativeScheduler-specific logic regarding timesteps + # Problem is that we don't know two timesteps that will always be in the timestep schedule from only the scheduler + # config; scaled sigma_max is always in the timestep schedule, but sigma_min is in the sigma schedule while scaled + # sigma_min is not in the timestep schedule + def test_step_shape(self): + num_inference_steps = 10 + + scheduler_config = self.get_scheduler_config() + scheduler = self.scheduler_classes[0](**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + timestep_0 = scheduler.timesteps[0] + 
timestep_1 = scheduler.timesteps[1] + + sample = self.dummy_sample + residual = 0.1 * sample + + output_0 = scheduler.step(residual, timestep_0, sample).prev_sample + output_1 = scheduler.step(residual, timestep_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_clip_denoised(self): + for clip_denoised in [True, False]: + self.check_over_configs(clip_denoised=clip_denoised) + + def test_full_loop_no_noise_onestep(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 1 + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(timesteps): + # 1. scale model input + scaled_sample = scheduler.scale_model_input(sample, t) + + # 2. predict noise residual + residual = model(scaled_sample, t) + + # 3. 
predict previous sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 192.7614) < 1e-2 + assert abs(result_mean.item() - 0.2510) < 1e-3 + + def test_full_loop_no_noise_multistep(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [106, 0] + scheduler.set_timesteps(timesteps=timesteps) + timesteps = scheduler.timesteps + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for t in timesteps: + # 1. scale model input + scaled_sample = scheduler.scale_model_input(sample, t) + + # 2. predict noise residual + residual = model(scaled_sample, t) + + # 3. predict previous sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 347.6357) < 1e-2 + assert abs(result_mean.item() - 0.4527) < 1e-3 + + def test_custom_timesteps_increasing_order(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [39, 30, 12, 15, 0] + + with self.assertRaises(ValueError, msg="`timesteps` must be in descending order."): + scheduler.set_timesteps(timesteps=timesteps) + + def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [39, 30, 12, 1, 0] + num_inference_steps = len(timesteps) + + with 
self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `timesteps`."): + scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def test_custom_timesteps_too_large(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [scheduler.config.num_train_timesteps] + + with self.assertRaises( + ValueError, + msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}", + ): + scheduler.set_timesteps(timesteps=timesteps) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index d1ae333c0cd2..d9423d621966 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -24,6 +24,7 @@ import diffusers from diffusers import ( + CMStochasticIterativeScheduler, DDIMScheduler, DEISMultistepScheduler, DiffusionPipeline, @@ -303,6 +304,11 @@ def check_over_configs(self, time_step=0, **config): scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + time_step = scaled_sigma_max + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -323,7 +329,11 @@ def check_over_configs(self, time_step=0, **config): kwargs["num_inference_steps"] = num_inference_steps # Make sure `scale_model_input` is invoked to prevent a warning - if scheduler_class != VQDiffusionScheduler: + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. 
+ _ = scheduler.scale_model_input(sample, scaled_sigma_max) + _ = new_scheduler.scale_model_input(sample, scaled_sigma_max) + elif scheduler_class != VQDiffusionScheduler: _ = scheduler.scale_model_input(sample, 0) _ = new_scheduler.scale_model_input(sample, 0) @@ -393,6 +403,10 @@ def test_from_save_pretrained(self): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + timestep = scheduler.sigma_to_t(scheduler.config.sigma_max) + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -539,6 +553,10 @@ def recursive_check(tuple_object, dict_object): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + timestep = scheduler.sigma_to_t(scheduler.config.sigma_max) + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -594,7 +612,12 @@ def test_scheduler_public_api(self): if scheduler_class != VQDiffusionScheduler: sample = self.dummy_sample - scaled_sample = scheduler.scale_model_input(sample, 0.0) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. 
+ scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + else: + scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) def test_add_noise_device(self): @@ -606,7 +629,12 @@ def test_add_noise_device(self): scheduler.set_timesteps(100) sample = self.dummy_sample.to(torch_device) - scaled_sample = scheduler.scale_model_input(sample, 0.0) + if scheduler_class == CMStochasticIterativeScheduler: + # Get valid timestep based on sigma_max, which should always be in timestep schedule. + scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) + scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + else: + scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) noise = torch.randn_like(scaled_sample).to(torch_device) @@ -637,7 +665,7 @@ def test_deprecated_kwargs(self): def test_trained_betas(self): for scheduler_class in self.scheduler_classes: - if scheduler_class == VQDiffusionScheduler: + if scheduler_class in (VQDiffusionScheduler, CMStochasticIterativeScheduler): continue scheduler_config = self.get_scheduler_config()