From b7c3a033e0d3053e44e8a3abe5ebffe0a30bcf49 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 02:53:55 +0530 Subject: [PATCH 01/12] differential diffusion initial draft --- .../differential_diffusion_mixin.py | 103 ++++++++++++++++++ .../pipeline_differential_diffusion_sdxl.py | 95 ++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 examples/research_projects/differential_diffusion/differential_diffusion_mixin.py create mode 100644 examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py new file mode 100644 index 000000000000..03bd05f65d69 --- /dev/null +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -0,0 +1,103 @@ +import inspect +from typing import Any, Dict + +import torch + + +def _get_default_value(func, arg_name): + return inspect.signature(func).parameters[arg_name].default + + +class DifferentialDiffusionMixin: + def __init__(self): + if not hasattr(self, "prepare_latents"): + raise ValueError("`prepare_latents` must be implemented in the model class.") + + prepare_latents_possible_kwargs = inspect.signature(self.prepare_latents).parameters.keys() + prepare_latents_required_kwargs = [ + "image", + "num_inference_steps", + "num_images_per_prompt", + "generator", + "denoising_start", + ] + + if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): + raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") + + @torch.no_grad() + def __call__(self, map: torch.FloatTensor, **kwargs): + if map is None: + raise ValueError("`map` must be provided for differential diffusion.") + + original_with_noise = thresholds = masks = None + original_callback_on_step_end = kwargs.pop("callback_on_step_end", None) + original_callback_on_step_end_tensor_inputs = kwargs.pop("callback_on_step_end_tensor_inputs", []) + + callback_on_step_end_tensor_inputs_required = ["timesteps", "batch_size", "prompt_embeds", "device", "latents"] + callback_on_step_end_tensor_inputs = list( + set(callback_on_step_end_tensor_inputs_required + original_callback_on_step_end_tensor_inputs) + ) + + image = kwargs.pop("image", _get_default_value(self.__call__, "image")) + num_inference_steps = kwargs.pop( + "num_inference_steps", _get_default_value(self.__call__, "num_inference_steps") + ) + num_images_per_prompt = kwargs.pop( + "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") + ) + generator = kwargs.pop("generator", _get_default_value(self.__call__, "generator")) + denoising_start = kwargs.pop("denoising_start", _get_default_value(self.__call__, "denoising_start")) + + def callback(i: int, t: int, callback_kwargs: Dict[str, Any]): + nonlocal original_with_noise, thresholds, masks + + timesteps = callback_kwargs.get("timesteps") + batch_size = callback_kwargs.get("batch_size") + prompt_embeds = callback_kwargs.get("prompt_embeds") + latents = callback_kwargs.get("latents") + + if i < 0: + original_with_noise = self.prepare_latents( + image=image, + timestep=timesteps, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + dtype=prompt_embeds.dtype, + device=prompt_embeds.device, + generator=generator, + ) + thresholds = torch.arange(num_inference_steps, dtype=map.dtype) / num_inference_steps + thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(prompt_embeds.device) + masks = map > (thresholds + (denoising_start or 0)) + + if denoising_start is None: + latents = original_with_noise[:1] + elif i == 0: + pass + else: + mask = masks[i].unsqueeze(0) + mask = mask.to(latents.dtype) + mask = mask.unsqueeze(1) + latents = original_with_noise[i] * mask + latents * (1 - mask) + + callback_results = {} + + if original_callback_on_step_end is not None: + callback_kwargs["latents"] = latents + result = original_callback_on_step_end(i, t, callback_kwargs) + callback_results.update(result) + + if "latents" in result: + latents = result["latents"] + + callback_results["latents"] = latents + + return callback_results + + return super().__call__( + callback_before_step_begin=True, + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py new file mode 100644 index 000000000000..222b12531bad --- /dev/null +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py @@ -0,0 +1,95 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from diffusers import StableDiffusionXLImg2ImgPipeline +from diffusers.image_processor import PipelineImageInput + +from .differential_diffusion_mixin import DifferentialDiffusionMixin + + +class DifferentialDiffusionSDXLPipeline(StableDiffusionXLImg2ImgPipeline, DifferentialDiffusionMixin): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + # Differential Diffusion specific + map: torch.FloatTensor = None, + **kwargs, + ): + return DifferentialDiffusionMixin.__call__( + self, + prompt=prompt, + prompt_2=prompt_2, + image=image, + strength=strength, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + denoising_start=denoising_start, + denoising_end=denoising_end, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + output_type=output_type, + return_dict=return_dict, + cross_attention_kwargs=cross_attention_kwargs, + guidance_rescale=guidance_rescale, + original_size=original_size, + crops_coords_top_left=crops_coords_top_left, + target_size=target_size, + negative_original_size=negative_original_size, + negative_crops_coords_top_left=negative_crops_coords_top_left, + negative_target_size=negative_target_size, + aesthetic_score=aesthetic_score, + negative_aesthetic_score=negative_aesthetic_score, + clip_skip=clip_skip, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + map=map, + **kwargs, + ) From e22d00492766ffd9a42925b247c57be69029059f Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 02:58:13 +0530 Subject: [PATCH 02/12] add callback_before_step_begin flag --- .../differential_diffusion_mixin.py | 6 ++--- .../pipeline_stable_diffusion_xl_img2img.py | 24 ++++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 03bd05f65d69..74dc6659834c 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -49,7 +49,7 @@ def __call__(self, map: torch.FloatTensor, **kwargs): generator = kwargs.pop("generator", _get_default_value(self.__call__, "generator")) denoising_start = kwargs.pop("denoising_start", _get_default_value(self.__call__, "denoising_start")) - def callback(i: int, t: int, callback_kwargs: Dict[str, Any]): + def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): nonlocal original_with_noise, thresholds, masks timesteps = callback_kwargs.get("timesteps") @@ -85,7 +85,7 @@ def callback(i: int, t: int, callback_kwargs: Dict[str, Any]): if original_callback_on_step_end is not None: callback_kwargs["latents"] = latents - result = original_callback_on_step_end(i, t, callback_kwargs) + result = original_callback_on_step_end(pipe, i, t, callback_kwargs) callback_results.update(result) if "latents" in result: @@ -96,8 +96,8 @@ def callback(i: int, t: int, callback_kwargs: Dict[str, Any]): return callback_results return super().__call__( - callback_before_step_begin=True, callback_on_step_end=callback, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + callback_before_step_begin=True, **kwargs, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 4b0ea1e3f3d1..c9c4bf8adde0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -983,6 +983,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_before_step_begin: bool = False, **kwargs, ): r""" @@ -1065,7 +1066,8 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding @@ -1135,6 +1137,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + callback_before_step_begin (`bool`, *optional*, defaults to `False`): + If `True`, the `callback_on_step_end` function will be called before the denoising step begins. Examples: @@ -1336,6 +1340,24 @@ def denoising_value_valid(dnv): ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) + + if callback_before_step_begin: + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, -1, timesteps, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: From 30ec3fe3a9463de351cc46a1d8fc9ec6b881830c Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 03:38:56 +0530 Subject: [PATCH 03/12] fix --- .../differential_diffusion/differential_diffusion_mixin.py | 6 ++++-- .../pipeline_differential_diffusion_sdxl.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 74dc6659834c..93c1967b5c70 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -16,10 +16,12 @@ def __init__(self): prepare_latents_possible_kwargs = inspect.signature(self.prepare_latents).parameters.keys() prepare_latents_required_kwargs = [ "image", - "num_inference_steps", + "timestep", + "batch_size", "num_images_per_prompt", + "dtype", + "device", "generator", - "denoising_start", ] if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py index 222b12531bad..ab58492fc522 100644 --- a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py @@ -1,12 +1,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from differential_diffusion_mixin import DifferentialDiffusionMixin from diffusers import StableDiffusionXLImg2ImgPipeline from diffusers.image_processor import PipelineImageInput -from .differential_diffusion_mixin import DifferentialDiffusionMixin - class DifferentialDiffusionSDXLPipeline(StableDiffusionXLImg2ImgPipeline, DifferentialDiffusionMixin): @torch.no_grad() From fd45215ede15342b8dd3ea437e08e7f001061d33 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 03:48:31 +0530 Subject: [PATCH 04/12] fix --- .../differential_diffusion_mixin.py | 213 +++++++++--------- .../pipeline_differential_diffusion_sdxl.py | 187 ++++++++------- 2 files changed, 201 insertions(+), 199 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 93c1967b5c70..30ba5a586e37 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -1,105 +1,108 @@ -import inspect -from typing import Any, Dict - -import torch - - -def _get_default_value(func, arg_name): - return inspect.signature(func).parameters[arg_name].default - - -class DifferentialDiffusionMixin: - def __init__(self): - if not hasattr(self, "prepare_latents"): - raise ValueError("`prepare_latents` must be implemented in the model class.") - - prepare_latents_possible_kwargs = inspect.signature(self.prepare_latents).parameters.keys() - prepare_latents_required_kwargs = [ - "image", - "timestep", - "batch_size", - "num_images_per_prompt", - "dtype", - "device", - "generator", - ] - - if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): - raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") - - @torch.no_grad() - def __call__(self, map: torch.FloatTensor, **kwargs): - if map is None: - raise ValueError("`map` must be provided for differential diffusion.") - - original_with_noise = thresholds = masks = None - original_callback_on_step_end = kwargs.pop("callback_on_step_end", None) - original_callback_on_step_end_tensor_inputs = kwargs.pop("callback_on_step_end_tensor_inputs", []) - - callback_on_step_end_tensor_inputs_required = ["timesteps", "batch_size", "prompt_embeds", "device", "latents"] - callback_on_step_end_tensor_inputs = list( - set(callback_on_step_end_tensor_inputs_required + original_callback_on_step_end_tensor_inputs) - ) - - image = kwargs.pop("image", _get_default_value(self.__call__, "image")) - num_inference_steps = kwargs.pop( - "num_inference_steps", _get_default_value(self.__call__, "num_inference_steps") - ) - num_images_per_prompt = kwargs.pop( - "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") - ) - generator = kwargs.pop("generator", _get_default_value(self.__call__, "generator")) - denoising_start = kwargs.pop("denoising_start", _get_default_value(self.__call__, "denoising_start")) - - def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): - nonlocal original_with_noise, thresholds, masks - - timesteps = callback_kwargs.get("timesteps") - batch_size = callback_kwargs.get("batch_size") - prompt_embeds = callback_kwargs.get("prompt_embeds") - latents = callback_kwargs.get("latents") - - if i < 0: - original_with_noise = self.prepare_latents( - image=image, - timestep=timesteps, - batch_size=batch_size, - num_images_per_prompt=num_images_per_prompt, - dtype=prompt_embeds.dtype, - device=prompt_embeds.device, - generator=generator, - ) - thresholds = torch.arange(num_inference_steps, dtype=map.dtype) / num_inference_steps - thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(prompt_embeds.device) - masks = map > (thresholds + (denoising_start or 0)) - - if denoising_start is None: - latents = original_with_noise[:1] - elif i == 0: - pass - else: - mask = masks[i].unsqueeze(0) - mask = mask.to(latents.dtype) - mask = mask.unsqueeze(1) - latents = original_with_noise[i] * mask + latents * (1 - mask) - - callback_results = {} - - if original_callback_on_step_end is not None: - callback_kwargs["latents"] = latents - result = original_callback_on_step_end(pipe, i, t, callback_kwargs) - callback_results.update(result) - - if "latents" in result: - latents = result["latents"] - - callback_results["latents"] = latents - - return callback_results - - return super().__call__( - callback_on_step_end=callback, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - callback_before_step_begin=True, - **kwargs, - ) +import inspect +from typing import Any, Dict + +import torch + + +def _get_default_value(func, arg_name): + return inspect.signature(func).parameters[arg_name].default + + +class DifferentialDiffusionMixin: + def __init__(self): + if not hasattr(self, "prepare_latents"): + raise ValueError("`prepare_latents` must be implemented in the model class.") + + prepare_latents_possible_kwargs = inspect.signature(self.prepare_latents).parameters.keys() + prepare_latents_required_kwargs = [ + "image", + "timestep", + "batch_size", + "num_images_per_prompt", + "dtype", + "device", + "generator", + ] + + if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): + raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") + + @torch.no_grad() + def inference(self, map: torch.FloatTensor, **kwargs): + if map is None: + raise ValueError("`map` must be provided for differential diffusion.") + + original_with_noise = thresholds = masks = None + original_callback_on_step_end = kwargs.pop("callback_on_step_end", None) + original_callback_on_step_end_tensor_inputs = kwargs.pop("callback_on_step_end_tensor_inputs", []) + + callback_on_step_end_tensor_inputs_required = ["timesteps", "batch_size", "prompt_embeds", "device", "latents"] + callback_on_step_end_tensor_inputs = list( + set(callback_on_step_end_tensor_inputs_required + original_callback_on_step_end_tensor_inputs) + ) + + image = kwargs.pop("image", _get_default_value(self.__call__, "image")) + num_inference_steps = kwargs.pop( + "num_inference_steps", _get_default_value(self.__call__, "num_inference_steps") + ) + num_images_per_prompt = kwargs.pop( + "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") + ) + generator = kwargs.pop("generator", _get_default_value(self.__call__, "generator")) + denoising_start = kwargs.pop("denoising_start", _get_default_value(self.__call__, "denoising_start")) + callback_before_step_begin = kwargs.pop( + "callback_before_step_begin", _get_default_value(self.__call__, "callback_before_step_begin") + ) + + def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): + nonlocal original_with_noise, thresholds, masks + + timesteps = callback_kwargs.get("timesteps") + batch_size = callback_kwargs.get("batch_size") + prompt_embeds = callback_kwargs.get("prompt_embeds") + latents = callback_kwargs.get("latents") + + if i < 0: + original_with_noise = self.prepare_latents( + image=image, + timestep=timesteps, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + dtype=prompt_embeds.dtype, + device=prompt_embeds.device, + generator=generator, + ) + thresholds = torch.arange(num_inference_steps, dtype=map.dtype) / num_inference_steps + thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(prompt_embeds.device) + masks = map > (thresholds + (denoising_start or 0)) + + if denoising_start is None: + latents = original_with_noise[:1] + elif i == 0: + pass + else: + mask = masks[i].unsqueeze(0) + mask = mask.to(latents.dtype) + mask = mask.unsqueeze(1) + latents = original_with_noise[i] * mask + latents * (1 - mask) + + callback_results = {} + + if original_callback_on_step_end is not None and (i >= 0 or callback_before_step_begin): + callback_kwargs["latents"] = latents + result = original_callback_on_step_end(pipe, i, t, callback_kwargs) + callback_results.update(result) + + if "latents" in result: + latents = result["latents"] + + callback_results["latents"] = latents + + return callback_results + + return super().__call__( + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + callback_before_step_begin=True, + **kwargs, + ) diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py index ab58492fc522..6d1b1f32218c 100644 --- a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py @@ -1,94 +1,93 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from differential_diffusion_mixin import DifferentialDiffusionMixin - -from diffusers import StableDiffusionXLImg2ImgPipeline -from diffusers.image_processor import PipelineImageInput - - -class DifferentialDiffusionSDXLPipeline(StableDiffusionXLImg2ImgPipeline, DifferentialDiffusionMixin): - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - strength: float = 0.3, - num_inference_steps: int = 50, - timesteps: List[int] = None, - denoising_start: Optional[float] = None, - denoising_end: Optional[float] = None, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - # Differential Diffusion specific - map: torch.FloatTensor = None, - **kwargs, - ): - return DifferentialDiffusionMixin.__call__( - self, - prompt=prompt, - prompt_2=prompt_2, - image=image, - strength=strength, - num_inference_steps=num_inference_steps, - timesteps=timesteps, - denoising_start=denoising_start, - denoising_end=denoising_end, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - ip_adapter_image=ip_adapter_image, - ip_adapter_image_embeds=ip_adapter_image_embeds, - output_type=output_type, - return_dict=return_dict, - cross_attention_kwargs=cross_attention_kwargs, - guidance_rescale=guidance_rescale, - original_size=original_size, - crops_coords_top_left=crops_coords_top_left, - target_size=target_size, - negative_original_size=negative_original_size, - negative_crops_coords_top_left=negative_crops_coords_top_left, - negative_target_size=negative_target_size, - aesthetic_score=aesthetic_score, - negative_aesthetic_score=negative_aesthetic_score, - clip_skip=clip_skip, - callback_on_step_end=callback_on_step_end, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - map=map, - **kwargs, - ) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from differential_diffusion_mixin import DifferentialDiffusionMixin + +from diffusers import StableDiffusionXLImg2ImgPipeline +from diffusers.image_processor import PipelineImageInput + + +class DifferentialDiffusionSDXLPipeline(StableDiffusionXLImg2ImgPipeline, DifferentialDiffusionMixin): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + # Differential Diffusion specific + map: torch.FloatTensor = None, + **kwargs, + ): + return self.inference( + prompt=prompt, + prompt_2=prompt_2, + image=image, + strength=strength, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + denoising_start=denoising_start, + denoising_end=denoising_end, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + output_type=output_type, + return_dict=return_dict, + cross_attention_kwargs=cross_attention_kwargs, + guidance_rescale=guidance_rescale, + original_size=original_size, + crops_coords_top_left=crops_coords_top_left, + target_size=target_size, + negative_original_size=negative_original_size, + negative_crops_coords_top_left=negative_crops_coords_top_left, + negative_target_size=negative_target_size, + aesthetic_score=aesthetic_score, + negative_aesthetic_score=negative_aesthetic_score, + clip_skip=clip_skip, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + map=map, + **kwargs, + ) From d4343be518455203122daa8c82a7c3b8a35c7929 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 04:40:34 +0530 Subject: [PATCH 05/12] remove _callback_tensor_inputs restriction --- .../differential_diffusion_mixin.py | 4 ++-- .../pipeline_differential_diffusion_sdxl.py | 4 ++-- .../pipeline_stable_diffusion_xl_img2img.py | 18 ------------------ 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 30ba5a586e37..006e620d4c95 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -28,7 +28,7 @@ def __init__(self): raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") @torch.no_grad() - def inference(self, map: torch.FloatTensor, **kwargs): + def _inference(self, map: torch.FloatTensor, **kwargs): if map is None: raise ValueError("`map` must be provided for differential diffusion.") @@ -100,7 +100,7 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): return callback_results - return super().__call__( + return self.__call__( callback_on_step_end=callback, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_before_step_begin=True, diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py index 6d1b1f32218c..c331bb328316 100644 --- a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py @@ -9,7 +9,7 @@ class DifferentialDiffusionSDXLPipeline(StableDiffusionXLImg2ImgPipeline, DifferentialDiffusionMixin): @torch.no_grad() - def __call__( + def inference( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, @@ -51,7 +51,7 @@ def __call__( map: torch.FloatTensor = None, **kwargs, ): - return self.inference( + return self._inference( prompt=prompt, prompt_2=prompt_2, image=image, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index c9c4bf8adde0..1e305c112acf 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -228,15 +228,6 @@ class StableDiffusionXLImg2ImgPipeline( "image_encoder", "feature_extractor", ] - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - "add_text_embeds", - "add_time_ids", - "negative_pooled_prompt_embeds", - "add_neg_time_ids", - ] def __init__( self, @@ -544,7 +535,6 @@ def check_inputs( negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, - callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -561,13 +551,6 @@ def check_inputs( f" {type(callback_steps)}." ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -1177,7 +1160,6 @@ def __call__( negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, - callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale From a7bed83ffe8b71d2207eb00281d6848ff02a866e Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 23 Mar 2024 05:09:56 +0530 Subject: [PATCH 06/12] .pop() -> .get() --- .../differential_diffusion/README.md | 5 +++++ .../differential_diffusion_mixin.py | 11 +++++------ 2 files changed, 10 insertions(+), 6 deletions(-) create mode 100644 examples/research_projects/differential_diffusion/README.md diff --git a/examples/research_projects/differential_diffusion/README.md b/examples/research_projects/differential_diffusion/README.md new file mode 100644 index 000000000000..26aad060279a --- /dev/null +++ b/examples/research_projects/differential_diffusion/README.md @@ -0,0 +1,5 @@ +# Differential Diffusion + +- Paper: https://differential-diffusion.github.io/paper.pdf +- Project site: https://differential-diffusion.github.io/ +- Code: https://github.com/exx8/differential-diffusion diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 006e620d4c95..09690d38a049 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -27,7 +27,6 @@ def __init__(self): if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") - @torch.no_grad() def _inference(self, map: torch.FloatTensor, **kwargs): if map is None: raise ValueError("`map` must be provided for differential diffusion.") @@ -41,15 +40,15 @@ def _inference(self, map: torch.FloatTensor, **kwargs): set(callback_on_step_end_tensor_inputs_required + original_callback_on_step_end_tensor_inputs) ) - image = kwargs.pop("image", _get_default_value(self.__call__, "image")) - num_inference_steps = kwargs.pop( + image = kwargs.get("image", _get_default_value(self.__call__, "image")) + num_inference_steps = kwargs.get( "num_inference_steps", _get_default_value(self.__call__, "num_inference_steps") ) - num_images_per_prompt = kwargs.pop( + num_images_per_prompt = kwargs.get( "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") ) - generator = kwargs.pop("generator", _get_default_value(self.__call__, "generator")) - denoising_start = kwargs.pop("denoising_start", _get_default_value(self.__call__, "denoising_start")) + generator = kwargs.get("generator", _get_default_value(self.__call__, "generator")) + denoising_start = kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) callback_before_step_begin = kwargs.pop( "callback_before_step_begin", _get_default_value(self.__call__, "callback_before_step_begin") ) From 48b66db538040b48b5a138adda2ef27702087c8d Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 01:29:37 +0530 Subject: [PATCH 07/12] refactor --- .../differential_diffusion_mixin.py | 10 ++------ .../pipeline_stable_diffusion_xl_img2img.py | 24 +------------------ 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 09690d38a049..1d78fe28fa93 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -49,9 +49,6 @@ def _inference(self, map: torch.FloatTensor, **kwargs): ) generator = kwargs.get("generator", _get_default_value(self.__call__, "generator")) denoising_start = kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) - callback_before_step_begin = kwargs.pop( - "callback_before_step_begin", _get_default_value(self.__call__, "callback_before_step_begin") - ) def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): nonlocal original_with_noise, thresholds, masks @@ -61,7 +58,7 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): prompt_embeds = callback_kwargs.get("prompt_embeds") latents = callback_kwargs.get("latents") - if i < 0: + if i == 0: original_with_noise = self.prepare_latents( image=image, timestep=timesteps, @@ -77,8 +74,6 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): if denoising_start is None: latents = original_with_noise[:1] - elif i == 0: - pass else: mask = masks[i].unsqueeze(0) mask = mask.to(latents.dtype) @@ -87,7 +82,7 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): callback_results = {} - if original_callback_on_step_end is not None and (i >= 0 or callback_before_step_begin): + if original_callback_on_step_end is not None: callback_kwargs["latents"] = latents result = original_callback_on_step_end(pipe, i, t, callback_kwargs) callback_results.update(result) @@ -102,6 +97,5 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): return self.__call__( callback_on_step_end=callback, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - callback_before_step_begin=True, **kwargs, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 1e305c112acf..3e29e3e46d62 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -966,7 +966,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - callback_before_step_begin: bool = False, **kwargs, ): r""" @@ -1049,8 +1048,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): - Optional image input to work with IP Adapters. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding @@ -1120,8 +1118,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - callback_before_step_begin (`bool`, *optional*, defaults to `False`): - If `True`, the `callback_on_step_end` function will be called before the denoising step begins. Examples: @@ -1322,24 +1318,6 @@ def denoising_value_valid(dnv): ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) - - if callback_before_step_begin: - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, -1, timesteps, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: From 48435945b1af7771d2d2d0e6a8bc1b0db2f5b9bf Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 01:34:45 +0530 Subject: [PATCH 08/12] add sd example for testing --- .../differential_diffusion_mixin.py | 9 ++- .../pipeline_stable_diffusion_sd.py | 61 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 1d78fe28fa93..5eb25fef2817 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -27,6 +27,8 @@ def __init__(self): if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") + self._is_sdxl = hasattr(self, "text_encoder_2") + def _inference(self, map: torch.FloatTensor, **kwargs): if map is None: raise ValueError("`map` must be provided for differential diffusion.") @@ -48,7 +50,7 @@ def _inference(self, map: torch.FloatTensor, **kwargs): "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") ) generator = kwargs.get("generator", _get_default_value(self.__call__, "generator")) - denoising_start = kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) + denoising_start = kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) if self._is_sdxl else None def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): nonlocal original_with_noise, thresholds, masks @@ -70,7 +72,10 @@ def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): ) thresholds = torch.arange(num_inference_steps, dtype=map.dtype) / num_inference_steps thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(prompt_embeds.device) - masks = map > (thresholds + (denoising_start or 0)) + if self._is_sdxl: + masks = map > (thresholds + (denoising_start or 0)) + else: + masks = map > thresholds if denoising_start is None: latents = original_with_noise[:1] diff --git a/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py b/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py new file mode 100644 index 000000000000..28289c47853c --- /dev/null +++ b/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py @@ -0,0 +1,61 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from differential_diffusion_mixin import DifferentialDiffusionMixin + +from diffusers import StableDiffusionImg2ImgPipeline +from diffusers.image_processor import PipelineImageInput + + +class DifferentialDiffusionSDPipeline(StableDiffusionImg2ImgPipeline, DifferentialDiffusionMixin): + @torch.no_grad() + def inference( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + # Differential Diffusion specific + map: torch.FloatTensor = None, + **kwargs, + ): + return self._inference( + prompt=prompt, + image=image, + strength=strength, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + output_type=output_type, + return_dict=return_dict, + cross_attention_kwargs=cross_attention_kwargs, + clip_skip=clip_skip, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + map=map, + **kwargs, + ) From 30f0cab5adcaafd5bad2bacde286b6e42943ce20 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 01:45:24 +0530 Subject: [PATCH 09/12] allow all locals to be captured in callback --- .../differential_diffusion_mixin.py | 6 +++++- .../pipeline_stable_diffusion_sd.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 9 --------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index 5eb25fef2817..cbefd5360f46 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -50,7 +50,11 @@ def _inference(self, map: torch.FloatTensor, **kwargs): "num_images_per_prompt", _get_default_value(self.__call__, "num_images_per_prompt") ) generator = kwargs.get("generator", _get_default_value(self.__call__, "generator")) - denoising_start = kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) if self._is_sdxl else None + denoising_start = ( + kwargs.get("denoising_start", _get_default_value(self.__call__, "denoising_start")) + if self._is_sdxl + else None + ) def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): nonlocal original_with_noise, thresholds, masks diff --git a/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py b/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py index 28289c47853c..6e6702681e7c 100644 --- a/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py +++ b/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from differential_diffusion_mixin import DifferentialDiffusionMixin diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 540eed6ebd56..c79d4be7c679 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -198,7 +198,6 @@ class StableDiffusionImg2ImgPipeline( model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -639,7 +638,6 @@ def check_inputs( negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, - callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -650,12 +648,6 @@ def check_inputs( f" {type(callback_steps)}." ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -957,7 +949,6 @@ def __call__( negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, - callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale From 86fc3f3ab7cd5d17a24b79300b15a26f370328f7 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 01:46:27 +0530 Subject: [PATCH 10/12] rename --- ...able_diffusion_sd.py => pipeline_differential_diffusion_sd.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/research_projects/differential_diffusion/{pipeline_stable_diffusion_sd.py => pipeline_differential_diffusion_sd.py} (100%) diff --git a/examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py similarity index 100% rename from examples/research_projects/differential_diffusion/pipeline_stable_diffusion_sd.py rename to examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py From 6ced1110bb8cba548a0b67ea4bf093d08133d92f Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 03:02:06 +0530 Subject: [PATCH 11/12] fix --- .../differential_diffusion_mixin.py | 32 +++++++++++++++---- .../pipeline_differential_diffusion_sd.py | 2 ++ .../pipeline_differential_diffusion_sdxl.py | 2 ++ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py index cbefd5360f46..11de08bcc874 100644 --- a/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py +++ b/examples/research_projects/differential_diffusion/differential_diffusion_mixin.py @@ -2,6 +2,9 @@ from typing import Any, Dict import torch +from torchvision import transforms + +from diffusers.image_processor import PipelineImageInput def _get_default_value(func, arg_name): @@ -27,22 +30,32 @@ def __init__(self): if not all(kwarg in prepare_latents_possible_kwargs for kwarg in prepare_latents_required_kwargs): raise ValueError(f"`prepare_latents` must have the following arguments: {prepare_latents_required_kwargs}") - self._is_sdxl = hasattr(self, "text_encoder_2") - - def _inference(self, map: torch.FloatTensor, **kwargs): + def _inference(self, original_image: PipelineImageInput, map: torch.FloatTensor, **kwargs): + if original_image is None: + raise ValueError("`original_image` must be provided for differential diffusion.") if map is None: raise ValueError("`map` must be provided for differential diffusion.") + self._is_sdxl = hasattr(self, "text_encoder_2") + kwargs["num_images_per_prompt"] = 1 + original_with_noise = thresholds = masks = None original_callback_on_step_end = kwargs.pop("callback_on_step_end", None) original_callback_on_step_end_tensor_inputs = kwargs.pop("callback_on_step_end_tensor_inputs", []) - callback_on_step_end_tensor_inputs_required = ["timesteps", "batch_size", "prompt_embeds", "device", "latents"] + callback_on_step_end_tensor_inputs_required = [ + "timesteps", + "batch_size", + "prompt_embeds", + "device", + "latents", + "height", + "width", + ] callback_on_step_end_tensor_inputs = list( set(callback_on_step_end_tensor_inputs_required + original_callback_on_step_end_tensor_inputs) ) - image = kwargs.get("image", _get_default_value(self.__call__, "image")) num_inference_steps = kwargs.get( "num_inference_steps", _get_default_value(self.__call__, "num_inference_steps") ) @@ -57,16 +70,21 @@ def _inference(self, map: torch.FloatTensor, **kwargs): ) def callback(pipe, i: int, t: int, callback_kwargs: Dict[str, Any]): - nonlocal original_with_noise, thresholds, masks + nonlocal original_with_noise, thresholds, masks, map + height = callback_kwargs.get("height") + width = callback_kwargs.get("width") timesteps = callback_kwargs.get("timesteps") batch_size = callback_kwargs.get("batch_size") prompt_embeds = callback_kwargs.get("prompt_embeds") latents = callback_kwargs.get("latents") if i == 0: + map = transforms.Resize( + (height // pipe.vae_scale_factor, width // pipe.vae_scale_factor), antialias=None + )(map) original_with_noise = self.prepare_latents( - image=image, + image=original_image, timestep=timesteps, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py index 6e6702681e7c..f8c513a6e7aa 100644 --- a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sd.py @@ -32,6 +32,7 @@ def inference( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], # Differential Diffusion specific + original_image: PipelineImageInput = None, map: torch.FloatTensor = None, **kwargs, ): @@ -56,6 +57,7 @@ def inference( clip_skip=clip_skip, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + original_image=original_image, map=map, **kwargs, ) diff --git a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py index c331bb328316..aa0852acfc7a 100644 --- a/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py +++ b/examples/research_projects/differential_diffusion/pipeline_differential_diffusion_sdxl.py @@ -48,6 +48,7 @@ def inference( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], # Differential Diffusion specific + original_image: PipelineImageInput = None, map: torch.FloatTensor = None, **kwargs, ): @@ -88,6 +89,7 @@ def inference( clip_skip=clip_skip, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + original_image=original_image, map=map, **kwargs, ) From ea2ecbc9ae8a4762a72a86d73d28f108d3a80356 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 04:10:05 +0530 Subject: [PATCH 12/12] update readme --- .../differential_diffusion/README.md | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/examples/research_projects/differential_diffusion/README.md b/examples/research_projects/differential_diffusion/README.md index 26aad060279a..8000d28b7d39 100644 --- a/examples/research_projects/differential_diffusion/README.md +++ b/examples/research_projects/differential_diffusion/README.md @@ -1,5 +1,112 @@ # Differential Diffusion +> Diffusion models have revolutionized image generation and editing, producing state-of-the-art results in conditioned and unconditioned image synthesis. While current techniques enable user control over the degree of change in an image edit, the controllability is limited to global changes over an entire edited region. This paper introduces a novel framework that enables customization of the amount of change per pixel or per image region. Our framework can be integrated into any existing diffusion model, enhancing it with this capability. Such granular control on the quantity of change opens up a diverse array of new editing capabilities, such as control of the extent to which individual objects are modified, or the ability to introduce gradual spatial changes. Furthermore, we showcase the framework's effectiveness in soft-inpainting—the completion of portions of an image while subtly adjusting the surrounding areas to ensure seamless integration. Additionally, we introduce a new tool for exploring the effects of different change quantities. Our framework operates solely during inference, requiring no model training or fine-tuning. We demonstrate our method with the current open state-of-the-art models, and validate it via both quantitative and qualitative comparisons, and a user study. + - Paper: https://differential-diffusion.github.io/paper.pdf - Project site: https://differential-diffusion.github.io/ - Code: https://github.com/exx8/differential-diffusion + +### Usage + +```py +import torch +from torchvision import transforms +from PIL import Image + +from diffusers.schedulers import DEISMultistepScheduler, DPMSolverSDEScheduler +from pipeline_differential_diffusion_sdxl import DifferentialDiffusionSDXLPipeline + + +def preprocess_image(image, device="cuda"): + image = image.convert("RGB") + image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image) + image = transforms.ToTensor()(image) + image = image * 2 - 1 + image = image.unsqueeze(0).to(device) + return image + + +def preprocess_map(map, height, width, device="cuda"): + map = map.convert("L") + map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map) + map = transforms.ToTensor()(map) + map = map.to(device) + return map + + +model_id = "stabilityai/stable-diffusion-xl-base-1.0" +pipe = DifferentialDiffusionSDXLPipeline.from_pretrained( + model_id, + torch_dtype=torch.float16, + variant="fp16", + cache_dir="/workspace", +).to("cuda") +refiner = DifferentialDiffusionSDXLPipeline.from_pretrained( + model_id, + text_encoder_2=pipe.text_encoder_2, + vae=pipe.vae, + torch_dtype=torch.float16, + variant="fp16", + cache_dir="/workspace", +).to("cuda") + +# enable memory savings +pipe.enable_vae_slicing() +refiner.enable_vae_slicing() + +image = Image.open("image.png") +map = Image.open("mask.png") + +processed_image = preprocess_image(image) +processed_map = preprocess_map(map, processed_image.shape[2], processed_image.shape[3]) + +prompt = "a crow sitting on a branch, photorealistic, high quality" +negative_prompt = "unrealistic, logo, jpeg artifacts, low quality, worst quality, cartoon, animated" +generator = torch.Generator().manual_seed(42) +guidance_scale = 24 +strength = 1 +denoise_boundary = 0.8 +num_inference_steps = 50 + +# If you want to use with refiner +latent = pipe.inference( + prompt=prompt, + negative_prompt=negative_prompt, + original_image=processed_image, + image=processed_image, + map=processed_map, + strength=strength, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + denoising_end=denoise_boundary, + output_type="latent", + generator=generator, +).images[0] +output = pipe.inference( + prompt=prompt, + negative_prompt=negative_prompt, + original_image=processed_image, + image=latent, + map=processed_map, + strength=strength, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + denoising_start=denoise_boundary, + generator=generator, +).images[0] + +# If you want to use without refiner +output = pipe.inference( + prompt=prompt, + negative_prompt=negative_prompt, + original_image=processed_image, + image=processed_image, + map=processed_map, + strength=strength, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, +).images[0] + +output.save("result.png") +```