From 7c638f5b5137d7406691d82c56c5ace6f67cb487 Mon Sep 17 00:00:00 2001 From: "Chambon, Thomas" Date: Fri, 7 Jul 2023 17:22:12 +0200 Subject: [PATCH 1/3] community pipeline: implementation of iadb --- examples/community/README.md | 60 +++++++++++++ examples/community/iadb.py | 157 +++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 examples/community/iadb.py diff --git a/examples/community/README.md b/examples/community/README.md index 17cd34a5182d..6967d273e449 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -38,6 +38,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | | CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) | | TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | +| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. 
```py @@ -1707,3 +1708,62 @@ output = pipeline( ``` ![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png) ![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png) + + +### IADB pipeline + +This pipeline is the implementation of the [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper. +It is a simple and minimalist diffusion model. + +The following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model. + +```python + +pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline='iadb') + +pipeline_iadb = pipeline_iadb.to('cuda') + +output = pipeline_iadb(batch_size=4,num_inference_steps=128) +for i in range(len(output[0])): + plt.imshow(output[0][i]) + plt.show() + +``` + +Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it): + +```python + +def sample_iadb(model, x0, nb_step): + x_alpha = x0 + for t in range(nb_step): + alpha = (t/nb_step) + alpha_next =((t+1)/nb_step) + + d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))['sample'] + x_alpha = x_alpha + (alpha_next-alpha)*d + + return x_alpha + +``` + +The training loop is also straightforward: + +```python + +# Training loop +while True: + x0 = sample_noise() + x1 = sample_dataset() + + alpha = torch.rand(batch_size) + + # Blend + x_alpha = (1-alpha) * x0 + alpha * x1 + + # Loss + loss = torch.sum((D(x_alpha, alpha)- (x1-x0))**2) + optimizer.zero_grad() + loss.backward() + optimizer.step() +``` diff --git a/examples/community/iadb.py b/examples/community/iadb.py new file mode 100644 index 000000000000..72402a3d7f51 --- /dev/null +++ b/examples/community/iadb.py @@ -0,0 +1,157 @@ + +import torch +from diffusers import DiffusionPipeline +from diffusers.schedulers.scheduling_utils import SchedulerMixin, 
SchedulerOutput +from diffusers.schedulers import DDIMScheduler +from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers.configuration_utils import ConfigMixin +from typing import List, Optional, Tuple, Union + +class IADBScheduler(SchedulerMixin, ConfigMixin): + """ + IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist. + + For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html + """ + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x_alpha: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + Predict the sample at the previous timestep by reversing the ODE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. It is the direction from x0 to x1. + timestep (`int`): current timestep in the diffusion chain. 
+ x_alpha (`torch.FloatTensor`): x_alpha sample for the current timestep + + Returns: + `torch.FloatTensor`: the sample at the previous timestep + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + + alpha = timestep / self.num_inference_steps + alpha_next = (timestep+1) / self.num_inference_steps + + d = model_output + + x_alpha = x_alpha + (alpha_next-alpha) * d + + + return x_alpha + + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + alpha: torch.FloatTensor, + ) -> torch.FloatTensor: + + return original_samples * alpha + noise * (1 - alpha) + + + def __len__(self): + return self.config.num_train_timesteps + +class IADBPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + x_alpha = image.clone() + for t in self.progress_bar(range(num_inference_steps)): + alpha = t / num_inference_steps + + # 1. predict noise model_output + model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample + + # 2. 
step + x_alpha = self.scheduler.step( + model_output, t, x_alpha + ) + + + + image = (x_alpha * 0.5 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file From fa6885b4bf4f4373792b00485ddc72cafa81d92e Mon Sep 17 00:00:00 2001 From: "Chambon, Thomas" Date: Mon, 10 Jul 2023 17:10:44 +0200 Subject: [PATCH 2/3] iadb.py: reformat using black --- examples/community/iadb.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/examples/community/iadb.py b/examples/community/iadb.py index 72402a3d7f51..e97ee305e70a 100644 --- a/examples/community/iadb.py +++ b/examples/community/iadb.py @@ -1,12 +1,12 @@ - import torch from diffusers import DiffusionPipeline -from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers import DDIMScheduler from diffusers.pipeline_utils import ImagePipelineOutput from diffusers.configuration_utils import ConfigMixin from typing import List, Optional, Tuple, Union + class IADBScheduler(SchedulerMixin, ConfigMixin): """ IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist. @@ -18,7 +18,7 @@ def step( self, model_output: torch.FloatTensor, timestep: int, - x_alpha: torch.FloatTensor, + x_alpha: torch.FloatTensor, ) -> torch.FloatTensor: """ Predict the sample at the previous timestep by reversing the ODE. 
Core function to propagate the diffusion @@ -38,18 +38,15 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - alpha = timestep / self.num_inference_steps - alpha_next = (timestep+1) / self.num_inference_steps + alpha_next = (timestep + 1) / self.num_inference_steps d = model_output - x_alpha = x_alpha + (alpha_next-alpha) * d + x_alpha = x_alpha + (alpha_next - alpha) * d - return x_alpha - def set_timesteps(self, num_inference_steps: int): self.num_inference_steps = num_inference_steps @@ -59,13 +56,12 @@ def add_noise( noise: torch.FloatTensor, alpha: torch.FloatTensor, ) -> torch.FloatTensor: - return original_samples * alpha + noise * (1 - alpha) - def __len__(self): return self.config.num_train_timesteps + class IADBPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the @@ -81,7 +77,6 @@ class IADBPipeline(DiffusionPipeline): def __init__(self, unet, scheduler): super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() @@ -135,16 +130,12 @@ def __call__( x_alpha = image.clone() for t in self.progress_bar(range(num_inference_steps)): alpha = t / num_inference_steps - + # 1. predict noise model_output model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample # 2. 
step - x_alpha = self.scheduler.step( - model_output, t, x_alpha - ) - - + x_alpha = self.scheduler.step(model_output, t, x_alpha) image = (x_alpha * 0.5 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() @@ -154,4 +145,4 @@ def __call__( if not return_dict: return (image,) - return ImagePipelineOutput(images=image) \ No newline at end of file + return ImagePipelineOutput(images=image) From 151c814807b92358c035aba352d1cc75add7227f Mon Sep 17 00:00:00 2001 From: "Chambon, Thomas" Date: Mon, 10 Jul 2023 17:16:05 +0200 Subject: [PATCH 3/3] iadb.py: linting update --- examples/community/iadb.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/community/iadb.py b/examples/community/iadb.py index e97ee305e70a..1f421ee0ea4c 100644 --- a/examples/community/iadb.py +++ b/examples/community/iadb.py @@ -1,10 +1,11 @@ +from typing import List, Optional, Tuple, Union + import torch + from diffusers import DiffusionPipeline -from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -from diffusers.schedulers import DDIMScheduler -from diffusers.pipeline_utils import ImagePipelineOutput from diffusers.configuration_utils import ConfigMixin -from typing import List, Optional, Tuple, Union +from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers.schedulers.scheduling_utils import SchedulerMixin class IADBScheduler(SchedulerMixin, ConfigMixin):