From 819d98843afe6ee902ee50e246e540ee459eadad Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 11:17:34 +0530 Subject: [PATCH 01/10] add: support negative conditions. --- .../pipeline_stable_diffusion_xl.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index bf6c625bb2b6..74252c952ff9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -438,6 +438,7 @@ def check_inputs( negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, + negative_conditions=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -498,6 +499,30 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + if negative_conditions is not None: + if not isinstance(negative_conditions, dict): + raise ValueError( + "`negative_conditions` should be provided as a dictionary. Refer to the docstrings to learn more about how the dictionary should be structured." 
+ ) + else: + all_keys_present = all( + k in negative_conditions for k in {"original_size", "target_size", "coords_top_left"} + ) + if not all_keys_present: + raise ValueError( + f"When `negative_conditions` are provided, it's expected to have the following keys: `original_size`, `target_size`, and `coords_top_left`, but only the following keys were found:\n {list(negative_conditions.keys())}" + ) + else: + for k in negative_conditions: + current_condition = negative_conditions[k] + if not isinstance(current_condition, tuple): + raise ValueError(f"{k} in `negative_conditions` is expected to be a tuple.") + else: + if len(current_condition) != 2: + raise ValueError( + f"{k} in `negative_conditions` is expected to be a tuple of length 2." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -582,6 +607,7 @@ def __call__( original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, + negative_conditions: Dict[str, Tuple[int, int]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -681,6 +707,11 @@ def __call__( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_conditions (`Dict[str, Tuple[int, int]]`, *optional*, defaults to None): + Optional negative conditions to be provided to the UNet. 
Here's an example of how the dictionary should + be structured should it be provided: {"original_size": (512, 512), "crops_coords_top_left": (0, 0). + "targe_size": (1024, 1024)} For more information, refer to this issue thread: + https://github.com/huggingface/diffusers/issues/4208. Examples: @@ -709,6 +740,7 @@ def __call__( negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, + negative_conditions, ) # 2. Define call parameters @@ -776,11 +808,20 @@ def __call__( add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) + if negative_conditions is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_conditions["original_size"], + negative_conditions["crops_coords_top_left"], + negative_conditions["target_size"], + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) From e6e4e37719a39fc12af9bfd6bf99f066e54f22f1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 11:35:41 +0530 Subject: [PATCH 02/10] fix: key --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 74252c952ff9..8f2ad37500f4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -506,11 +506,11 @@ def 
check_inputs( ) else: all_keys_present = all( - k in negative_conditions for k in {"original_size", "target_size", "coords_top_left"} + k in negative_conditions for k in {"original_size", "target_size", "crops_coords_top_left"} ) if not all_keys_present: raise ValueError( - f"When `negative_conditions` are provided, it's expected to have the following keys: `original_size`, `target_size`, and `coords_top_left`, but only the following keys were found:\n {list(negative_conditions.keys())}" + f"When `negative_conditions` are provided, it's expected to have the following keys: `original_size`, `target_size`, and `crops_coords_top_left`, but only the following keys were found:\n {list(negative_conditions.keys())}" ) else: for k in negative_conditions: @@ -709,8 +709,8 @@ def __call__( section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_conditions (`Dict[str, Tuple[int, int]]`, *optional*, defaults to None): Optional negative conditions to be provided to the UNet. Here's an example of how the dictionary should - be structured should it be provided: {"original_size": (512, 512), "crops_coords_top_left": (0, 0). - "targe_size": (1024, 1024)} For more information, refer to this issue thread: + be structured should it be provided: {"original_size": (512, 512), "crops_coords_top_left": (0, 0), + "targe_size": (1024, 1024)}. For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 
Examples: From 0e85ab42d2e5b77a65b4c963f8e62e2ef3436410 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 11:54:21 +0530 Subject: [PATCH 03/10] add: tests --- .../test_stable_diffusion_xl.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 2d251a658658..86028f2a4ceb 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -128,7 +128,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, - "output_type": "numpy", + "output_type": "np", } return inputs @@ -689,3 +689,24 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + def test_stable_diffusion_xl_negative_conditions(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_with_no_neg_cond = image[0, -3:, -3:, -1] + + negative_conditions = { + "original_size": (512, 512), + "target_size": (1024, 1024), + "crops_coords_top_left": (0, 0), + } + image = sd_pipe(**inputs, negative_conditions=negative_conditions).images + image_slice_with_neg_cond = image[0, -3:, -3:, -1] + + self.assertTrue(np.abs(image_slice_with_no_neg_cond - image_slice_with_neg_cond).max() > 1e-2) From b8e313e55699bf4a7b1944a3109695ef2bbad4b0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 15:04:10 +0530 Subject: [PATCH 04/10] address PR feedback. 
--- .../pipeline_stable_diffusion_xl.py | 58 +++++++------------ .../test_stable_diffusion_xl.py | 12 ++-- 2 files changed, 28 insertions(+), 42 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8f2ad37500f4..0d11ae548c5f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -438,7 +438,6 @@ def check_inputs( negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, - negative_conditions=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -499,30 +498,6 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) - if negative_conditions is not None: - if not isinstance(negative_conditions, dict): - raise ValueError( - "`negative_conditions` should be provided as a dictionary. Refer to the docstrings to learn more about how the dictionary should be structured." 
- ) - else: - all_keys_present = all( - k in negative_conditions for k in {"original_size", "target_size", "crops_coords_top_left"} - ) - if not all_keys_present: - raise ValueError( - f"When `negative_conditions` are provided, it's expected to have the following keys: `original_size`, `target_size`, and `crops_coords_top_left`, but only the following keys were found:\n {list(negative_conditions.keys())}" - ) - else: - for k in negative_conditions: - current_condition = negative_conditions[k] - if not isinstance(current_condition, tuple): - raise ValueError(f"{k} in `negative_conditions` is expected to be a tuple.") - else: - if len(current_condition) != 2: - raise ValueError( - f"{k} in `negative_conditions` is expected to be a tuple of length 2." - ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -607,7 +582,9 @@ def __call__( original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, - negative_conditions: Dict[str, Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -707,11 +684,21 @@ def __call__( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
- negative_conditions (`Dict[str, Tuple[int, int]]`, *optional*, defaults to None): - Optional negative conditions to be provided to the UNet. Here's an example of how the dictionary should - be structured should it be provided: {"original_size": (512, 512), "crops_coords_top_left": (0, 0), - "targe_size": (1024, 1024)}. For more information, refer to this issue thread: - https://github.com/huggingface/diffusers/issues/4208. + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be the same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. Examples: @@ -740,7 +727,6 @@ def __call__( negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, - negative_conditions, ) # 2. 
Define call parameters @@ -808,11 +794,11 @@ def __call__( add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) - if negative_conditions is not None: + if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( - negative_conditions["original_size"], - negative_conditions["crops_coords_top_left"], - negative_conditions["target_size"], + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, dtype=prompt_embeds.dtype, ) else: diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 86028f2a4ceb..205183b12669 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -701,12 +701,12 @@ def test_stable_diffusion_xl_negative_conditions(self): image = sd_pipe(**inputs).images image_slice_with_no_neg_cond = image[0, -3:, -3:, -1] - negative_conditions = { - "original_size": (512, 512), - "target_size": (1024, 1024), - "crops_coords_top_left": (0, 0), - } - image = sd_pipe(**inputs, negative_conditions=negative_conditions).images + image = sd_pipe( + **inputs, + negative_original_size=(512, 512), + negative_crops_coords_top_left=(0, 0), + negative_target_size=(1024, 1024), + ).images image_slice_with_neg_cond = image[0, -3:, -3:, -1] self.assertTrue(np.abs(image_slice_with_no_neg_cond - image_slice_with_neg_cond).max() > 1e-2) From e43247f969d53377c615ad55253f49d8cbdc1c16 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 15:13:51 +0530 Subject: [PATCH 05/10] add documentation --- .../stable_diffusion/stable_diffusion_xl.md | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md 
b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md index 8486641da2c4..95b7fbbfa8ec 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md @@ -23,6 +23,7 @@ The abstract of the paper is the following: - Stable Diffusion XL works especially well with images between 768 and 1024. - Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders. - Stable Diffusion XL output image can be improved by making use of a refiner as shown below. +- One can make use of `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to influence the generation process. ### Available checkpoints: @@ -74,6 +75,37 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt=prompt).images[0] ``` +You can additionally pass "negative conditions" to steer the generation process like so: + +```python +from diffusers import StableDiffusionXLPipeline +import torch + +pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipe.to("cuda") + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +image = pipe( + prompt=prompt, + negative_original_size=(512, 512), + negative_crops_coords_top_left=(0, 0), + negative_target_size=(1024, 1024), +).images[0] +``` + +Here is a comparative example that shows the influence of using three `negative_original_size`s of +(128, 128), (256, 256), and (512, 512) respectively: + +![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/negative_conditions.png) + + + +One can use these negative conditions in the other SDXL pipelines ([Image-To-Image](#image-to-image), [Inpainting](#inpainting), 
[ControlNet](../controlnet_sdxl.md)) too! + + + ### Image-to-image You can use SDXL as follows for *image-to-image*: From 8d6af88611f43921a632cbce90cc5b956abedb57 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 15:31:14 +0530 Subject: [PATCH 06/10] add img2img support. --- .../pipeline_stable_diffusion_xl_img2img.py | 43 +++++++++++++++++-- .../test_stable_diffusion_xl_img2img.py | 27 ++++++++++++ 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index d07405d45bfc..397c05bc42b7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -592,14 +592,25 @@ def prepare_latents( return latents def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim @@ -688,6 
+699,9 @@ def __call__( original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, ): @@ -802,6 +816,21 @@ def __call__( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of @@ -906,6 +935,11 @@ def denoising_value_valid(dnv): target_size = target_size or (height, width) # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, @@ -913,6 +947,9 @@ def denoising_value_valid(dnv): target_size, aesthetic_score, negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, dtype=prompt_embeds.dtype, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 1e879151ac2f..7c7838b3e8c1 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -315,3 +315,30 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + def test_stable_diffusion_xl_img2img_negative_conditions(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + 
image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1] + + image = sd_pipe( + **inputs, + negative_original_size=(512, 512), + negative_crops_coords_top_left=( + 0, + 0, + ), + negative_target_size=(1024, 1024), + ).images + image_slice_with_neg_conditions = image[0, -3:, -3:, -1] + + assert ( + np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max() + > 1e-4 + ) From 8c65100420f405e20a6c34d14917e97771d7e54c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 15:49:45 +0530 Subject: [PATCH 07/10] add inpainting support. --- .../pipeline_stable_diffusion_xl_inpaint.py | 43 +++++++++++++++++-- .../test_stable_diffusion_xl_img2img.py | 5 ++- .../test_stable_diffusion_xl_inpaint.py | 29 ++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c480549aebb3..f9f3ed2f2631 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -793,14 +793,25 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list(original_size + crops_coords_top_left + 
(negative_aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim @@ -885,6 +896,9 @@ def __call__( original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, ): @@ -1005,6 +1019,21 @@ def __call__( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of @@ -1172,6 +1201,11 @@ def denoising_value_valid(dnv): target_size = target_size or (height, width) # 10. 
Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, @@ -1179,6 +1213,9 @@ def denoising_value_valid(dnv): target_size, aesthetic_score, negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, dtype=prompt_embeds.dtype, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 7c7838b3e8c1..04cbb09f5196 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -138,7 +138,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, - "output_type": "numpy", + "output_type": "np", "strength": 0.8, } return inputs @@ -319,7 +319,8 @@ def test_stable_diffusion_xl_multi_prompts(self): def test_stable_diffusion_xl_img2img_negative_conditions(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + + sd_pipe = self.pipeline_class(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 05ce3f11973e..a4c378adbb23 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -139,7 +139,7 @@ def 
get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", } return inputs @@ -471,3 +471,30 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + def test_stable_diffusion_xl_img2img_negative_conditions(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1] + + image = sd_pipe( + **inputs, + negative_original_size=(512, 512), + negative_crops_coords_top_left=( + 0, + 0, + ), + negative_target_size=(1024, 1024), + ).images + image_slice_with_neg_conditions = image[0, -3:, -3:, -1] + + assert ( + np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max() + > 1e-4 + ) From 26797fbdfadc6abed8f83b7959abef629a1d5ce0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 16:02:18 +0530 Subject: [PATCH 08/10] ad controlnet support --- .../controlnet/pipeline_controlnet_sd_xl.py | 31 ++++++++++++++++++- .../controlnet/test_controlnet_sdxl.py | 21 ++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9866875425f7..98f00de64425 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -789,6 +789,9 @@ def __call__( original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, 
int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -895,6 +898,22 @@ def __call__( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be the same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ Examples: Returns: @@ -1058,10 +1077,20 @@ def __call__( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 7906467e3918..2496b9c48ab8 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", "image": image, } @@ -680,6 +680,25 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_negative_conditions(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs).images + image_slice_without_neg_cond = image[0, -3:, -3:, -1] + + image = pipe( + **inputs, + negative_original_size=(512, 512), + negative_crops_coords_top_left=(0, 0), + negative_target_size=(1024, 1024), + ).images + image_slice_with_neg_cond = image[0, 
-3:, -3:, -1] + + self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2) + @slow @require_torch_gpu From d84a3a12350264ff2d00ee5715d76258ee963af4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 25 Aug 2023 16:03:14 +0530 Subject: [PATCH 09/10] Apply suggestions from code review --- .../en/api/pipelines/stable_diffusion/stable_diffusion_xl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md index 95b7fbbfa8ec..45f5db1414f8 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md @@ -23,7 +23,7 @@ The abstract of the paper is the following: - Stable Diffusion XL works especially well with images between 768 and 1024. - Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders. - Stable Diffusion XL output image can be improved by making use of a refiner as shown below. -- One make use of `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to influence the generation process. +- One can make use of `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to influence the generation process. ### Available checkpoints: From 250ee14ea930154955c976da556c956dc2745139 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 26 Aug 2023 08:53:42 +0530 Subject: [PATCH 10/10] modify wording in the doc. 
--- .../en/api/pipelines/stable_diffusion/stable_diffusion_xl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md index 45f5db1414f8..f6585f819928 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md @@ -75,7 +75,7 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt=prompt).images[0] ``` -You can additionally pass "negative conditions" to steer the generation process like so: +You can additionally pass negative conditions about an image's size and position to avoid undesirable cropping behavior in the generated image, and improve image resolution. Let's take an example: ```python from diffusers import StableDiffusionXLPipeline