From da5c09237a918c06e24e30fd53bc0bfe070af977 Mon Sep 17 00:00:00 2001 From: JoaoLages Date: Tue, 13 Sep 2022 20:43:49 +0100 Subject: [PATCH 1/6] =?UTF-8?q?add=20awesome=20features=20=F0=9F=8C=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pipelines/stable_diffusion/__init__.py | 3 + .../pipeline_stable_diffusion.py | 82 ++++++++++++++++--- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ffda93f1721..7b0009d2fc90 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -21,10 +21,13 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. + latents (`List[torch.Tensor]`, *optional*, returned when `output_latents=True` is passed) + List (one element for each diffusion step) of `torch.Tensor` of shape `(batch_size, in_channels, height // 8, width // 8)` """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: List[bool] + latents: Optional[List[torch.Tensor]] = None if is_transformers_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f02fa114a8e1..3af8c233c397 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -90,10 +90,29 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - @torch.no_grad() + def gradient_checkpointing_enable(self) -> None: + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + self.pipe.text_encoder.gradient_checkpointing_enable() + # TODO: activate gradient checkpointing for self.unet and self.vae + + def gradient_checkpointing_disable(self) -> None: + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + self.pipe.text_encoder.gradient_checkpointing_disable() + # TODO: disable gradient checkpointing for self.unet and self.vae + def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], torch.Tensor], height: Optional[int] = 512, width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, @@ -103,13 +122,16 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + output_latents: bool = False, + run_safety_checker: bool = True, + enable_grad: bool = False, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): + prompt (`str`, `List[str]` or `torch.Tensor`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. @@ -140,6 +162,13 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + output_latents (`bool`, *optional*, defaults to `False`): + Whether or not to return the latents from all the diffusion steps. See `latents` under returned tensors + for more details. + run_safety_checker (`bool`, *optional*, defaults to `True`): + Whether or not to return run the safety checker in the final generated image. + enable_grad (`bool`, *optional*, defaults to `False`): + Whether or not to enable gradient calculation during diffusion process. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -161,10 +190,26 @@ def __call__( device = "cuda" if torch.cuda.is_available() else "cpu" self.to(device) + # enable/disable grad + was_grad_enabled = torch.is_grad_enabled() + torch.set_grad_enabled(enable_grad) + if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) + elif torch.is_tensor(prompt): + if len(prompt.shape) == 2: + # Add batch dimension + prompt = prompt.unsqueeze(0) + + if len(prompt.shape) != 3: + raise ValueError( + f"If `prompt` is of type `torch.Tensor`, it is expected to have a 2 dimensions " + f"(sequence len, embedding dim) or 3 dimensions (batch size, sequence len, embedding dim), " + f"but found tensor with shape {prompt.shape}" + ) + batch_size = prompt.shape[0] else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -172,14 +217,17 @@ def __call__( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + if torch.is_tensor(prompt): + text_embeddings = prompt + else: + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -187,7 +235,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] + max_length = text_embeddings.shape[-2] uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) @@ -237,6 +285,7 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + all_latents = [latents] if output_latents else None for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -259,6 +308,10 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + if output_latents: + # save latents from all diffusion steps + all_latents.append(latents) + # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample @@ -276,4 +329,7 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + # reset + torch.set_grad_enabled(was_grad_enabled) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, latents=all_latents) From eee018dec79a4ed459cc4fa2b9fd81f704dadc7a Mon Sep 17 00:00:00 2001 From: JoaoLages Date: Tue, 13 Sep 2022 20:47:49 +0100 Subject: [PATCH 2/6] also add latents to tuple --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3af8c233c397..4fa6c439768e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -327,7 +327,7 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept) + return (image, has_nsfw_concept, latents) # reset torch.set_grad_enabled(was_grad_enabled) From 275ba790b032e3911318ea2e30a340de5c003066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lages?= Date: Tue, 13 Sep 2022 21:10:44 +0100 Subject: [PATCH 3/6] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4fa6c439768e..9d03f410c1d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -327,7 +327,7 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, latents) + return (image, has_nsfw_concept, all_latents) # reset torch.set_grad_enabled(was_grad_enabled) From 259c149f36456e0e5dc64ba428b772e54fd9b053 Mon Sep 17 00:00:00 2001 From: JoaoLages Date: Wed, 14 Sep 2022 08:40:33 +0100 Subject: [PATCH 4/6] add PR suggestions --- src/diffusers/pipelines/stable_diffusion/__init__.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 7b0009d2fc90..f1f7e7f274d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -21,7 +21,7 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. - latents (`List[torch.Tensor]`, *optional*, returned when `output_latents=True` is passed) + latents (`List[torch.Tensor]`, *optional*, returned when `output_latents=True` and `return_dict=True` is passed) List (one element for each diffusion step) of `torch.Tensor` of shape `(batch_size, in_channels, height // 8, width // 8)` """ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9d03f410c1d5..3ddea57e7be3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -132,7 +132,8 @@ def __call__( Args: prompt (`str`, `List[str]` or `torch.Tensor`): - The prompt or prompts to guide the image generation. + The prompt or prompts to guide the image generation. If a `torch.Tensor` is provided, it should + have the shape (sequence len, embedding dim) or (batch size, sequence len, embedding dim). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): @@ -327,7 +328,7 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, all_latents) + return (image, has_nsfw_concept) # reset torch.set_grad_enabled(was_grad_enabled) From 12ca96990b0e6d7ef34a0c4bbcd9819ffe2752bd Mon Sep 17 00:00:00 2001 From: JoaoLages Date: Fri, 16 Sep 2022 11:10:01 +0100 Subject: [PATCH 5/6] remove run_safety_checker option --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3ddea57e7be3..ba85d01c9d1e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -123,7 +123,6 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, output_latents: bool = False, - run_safety_checker: bool = True, enable_grad: bool = False, **kwargs, ): @@ -166,8 +165,6 @@ def __call__( output_latents (`bool`, *optional*, defaults to `False`): Whether or not to return the latents from all the diffusion steps. See `latents` under returned tensors for more details. - run_safety_checker (`bool`, *optional*, defaults to `True`): - Whether or not to return run the safety checker in the final generated image. enable_grad (`bool`, *optional*, defaults to `False`): Whether or not to enable gradient calculation during diffusion process. From 01ae240d56353a5c4e69b21a239a7543d74da1c8 Mon Sep 17 00:00:00 2001 From: JoaoLages Date: Fri, 16 Sep 2022 14:28:35 +0100 Subject: [PATCH 6/6] remove gradient checkpointing --- .../pipeline_stable_diffusion.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ba85d01c9d1e..92e91423b2c7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -90,26 +90,6 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def gradient_checkpointing_enable(self) -> None: - """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - self.pipe.text_encoder.gradient_checkpointing_enable() - # TODO: activate gradient checkpointing for self.unet and self.vae - - def gradient_checkpointing_disable(self) -> None: - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - self.pipe.text_encoder.gradient_checkpointing_disable() - # TODO: disable gradient checkpointing for self.unet and self.vae - def __call__( self, prompt: Union[str, List[str], torch.Tensor],