From fa6b508a467abcdf0b37e4554960f8de2e3ed917 Mon Sep 17 00:00:00 2001 From: Yukun Huang <563837686@qq.com> Date: Sun, 27 Aug 2023 04:27:56 +0800 Subject: [PATCH 1/2] Fix potential type conversion errors in SDXL pipelines --- examples/community/lpw_stable_diffusion_xl.py | 2 +- examples/community/stable_diffusion_xl_reference.py | 4 ++-- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_upscale.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../pipeline_stable_diffusion_xl_img2img.py | 2 +- .../pipeline_stable_diffusion_xl_inpaint.py | 2 +- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index 7c825c8dbabb..5ffb15cabf43 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -1187,9 +1187,9 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 45cd4e0a2e0e..76823b5f5fe6 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -772,12 +772,12 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb= if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 + # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9a8d5edc7e07..b7599cb413ca 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1177,9 +1177,9 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 8f0776361fc3..4acb65df6e78 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -724,10 +724,10 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # post-processing if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) else: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 11e575d68269..e689e01ff9ab 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -873,9 +873,9 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ff51f8765e4a..7c2f18f215b2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1030,9 +1030,9 @@ def denoising_value_valid(dnv): # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index eecbdc7e669e..cf9acb1d0212 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1336,9 +1336,9 @@ def denoising_value_valid(dnv): # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: return StableDiffusionXLPipelineOutput(images=latents) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 83345454443f..fff95216605d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -911,9 +911,9 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents From da1846516c974671377c6268ea257b2692b3385f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 28 Aug 2023 07:35:47 +0000 Subject: [PATCH 2/2] make sure vae stays in fp16 --- examples/community/lpw_stable_diffusion_xl.py | 16 +++++++++++----- .../stable_diffusion_xl_reference.py | 16 +++++++++++----- .../controlnet/pipeline_controlnet_sd_xl.py | 16 +++++++++++----- .../pipeline_stable_diffusion_upscale.py | 19 ++++++++++++------- .../pipeline_stable_diffusion_xl.py | 16 +++++++++++----- .../pipeline_stable_diffusion_xl_img2img.py | 16 +++++++++++----- .../pipeline_stable_diffusion_xl_inpaint.py | 16 +++++++++++----- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 16 +++++++++++----- 8 files changed, 89 insertions(+), 42 deletions(-) diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index 5ffb15cabf43..abfbfb5aa1c1 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -1184,13 +1184,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 76823b5f5fe6..b47c962701b6 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -772,13 +772,19 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb= if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index b7599cb413ca..730e522b54a6 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1174,13 +1174,19 @@ def __call__( self.controlnet.to("cpu") torch.cuda.empty_cache() - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 4acb65df6e78..2bfffba84f58 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -720,15 +720,20 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - - # post-processing if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e689e01ff9ab..2d4ef87bdf79 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -870,13 +870,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 7c2f18f215b2..aada52eecbcb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1027,13 +1027,19 @@ def denoising_value_valid(dnv): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index cf9acb1d0212..bc29cecdbdc9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1333,13 +1333,19 @@ def denoising_value_valid(dnv): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: return StableDiffusionXLPipelineOutput(images=latents) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index fff95216605d..45c90b0d8d66 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -908,13 +908,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - if not output_type == "latent": - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image)