diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 45b26de284af..14e5c4ab7cd1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -697,15 +697,11 @@ def __call__( # 10. Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) - image = self.decode_latents(latents.float()) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() # 11. Convert to PIL # has_nsfw_concept = False if output_type == "pil": + image = self.decode_latents(latents.float()) image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) image = self.numpy_to_pil(image) @@ -713,9 +709,18 @@ def __call__( # 11. Apply watermark if self.watermarker is not None: image = self.watermarker.apply_watermark(image) + elif output_type == "pt": + latents = 1 / self.vae.config.scaling_factor * latents.float() + image = self.vae.decode(latents).sample + has_nsfw_concept = None else: + image = self.decode_latents(latents.float()) has_nsfw_concept = None + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept)