From 0f1da5c187ba1e473a619ae66308c41bfce40692 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 Jun 2023 17:51:59 +0000 Subject: [PATCH 1/3] refactor x4 upscaler --- .../pipeline_stable_diffusion_upscale.py | 58 ++++++++++++------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 6bb463a6a65f..a6095ba1f38d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -21,6 +21,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor @@ -34,6 +35,11 @@ def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -125,6 +131,8 @@ def __init__( watermarker=watermarker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") self.register_to_config(max_noise_level=max_noise_level) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -432,14 +440,15 @@ def check_inputs( if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) and not isinstance(image, list) ): raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" ) - # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): if isinstance(prompt, str): batch_size = 1 else: @@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, @@ -506,7 +522,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. * num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -627,7 +643,7 @@ def __call__( ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) image = image.to(dtype=prompt_embeds.dtype, device=device) # 5. set timesteps @@ -722,25 +738,25 @@ def __call__( self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() - - # 11. Convert to PIL - if output_type == "pil": - image = self.decode_latents(latents) - + + # post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # 11. Apply watermark - if self.watermarker is not None: - image = self.watermarker.apply_watermark(image) - elif output_type == "pt": - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - has_nsfw_concept = None - else: - image = self.decode_latents(latents) - has_nsfw_concept = None + if output_type == "pil" and self.watermarker is not None: + image = self.watermarker.apply_watermark(image) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: From 5ef4544f0ab86d5a4d30e5ca209ea366c440bb9f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 Jun 2023 17:52:43 +0000 Subject: [PATCH 2/3] style --- .../stable_diffusion/pipeline_stable_diffusion_upscale.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index a6095ba1f38d..4c4f3998cb91 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -738,7 +738,7 @@ def __call__( self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() - + # post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] @@ -746,7 +746,7 @@ def __call__( else: image = latents has_nsfw_concept = None - + if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: @@ -754,7 +754,7 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - # 11. Apply watermark + # 11. Apply watermark if output_type == "pil" and self.watermarker is not None: image = self.watermarker.apply_watermark(image) From 5dead74592394961835d7843cdb6e5110bdf7313 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 Jun 2023 17:53:15 +0000 Subject: [PATCH 3/3] copies --- .../pipeline_stable_diffusion_latent_upscale.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index e0fecf6d353f..d67a7f894886 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -33,6 +33,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image):