From 7bfd0d9bc879aca20fb203e8315d5892576f50cc Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sun, 6 Nov 2022 04:02:41 +0800 Subject: [PATCH 01/10] StableDiffusion: Decode latents separately to run larger batches --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../pipeline_stable_diffusion_inpaint_legacy.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1ccc87804e68..a2f2b1877333 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -399,7 +399,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 08b14b36be89..a45bca373078 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -428,7 +428,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 34e8231c63ee..4b6f27a21324 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -465,7 +465,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 5c06b74bfa38..3c78a598bee2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -405,7 +405,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() From b7477334126b204f69ce90ee5b653022c2ae729e Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Fri, 18 Nov 2022 10:04:36 +0800 Subject: [PATCH 02/10] Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode --- src/diffusers/models/vae.py | 31 ++++++++++++++++++- .../pipeline_stable_diffusion.py | 18 ++++++++++- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 30de343d08ee..d24c4db5b420 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -565,6 +565,7 @@ def __init__( self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_sliced_decode = False def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) @@ -576,7 +577,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK return AutoencoderKLOutput(latent_dist=posterior) - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: z = self.post_quant_conv(z) dec = self.decoder(z) @@ -585,6 +586,34 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=dec) + def enable_sliced_decode(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices, to compute decoding in several + steps. This is useful to save some memory. + """ + self.use_sliced_decode = True + + def disable_sliced_decode(self): + r""" + Disable sliced VAE decoding. If `enable_sliced_decode` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.use_sliced_decode = False + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_sliced_decode: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a2f2b1877333..9acbf1bcbae7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -158,6 +158,22 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def enable_sliced_vae_decode(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices, to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_sliced_decode() + + def disable_sliced_vae_decode(self): + r""" + Disable sliced VAE decoding. If `enable_sliced_decode` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_sliced_decode() + def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, @@ -399,7 +415,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index a45bca373078..08b14b36be89 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -428,7 +428,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 4b6f27a21324..34e8231c63ee 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -465,7 +465,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 3c78a598bee2..5c06b74bfa38 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -405,7 +405,7 @@ def __call__( callback(i, t, latents) latents = 1 / 0.18215 * latents - image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)]) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() From 684088c3ce4debbba2f5e767aeba5fc8184d0970 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Fri, 18 Nov 2022 10:14:40 +0800 Subject: [PATCH 03/10] Rename sliced_decode to slicing --- src/diffusers/models/vae.py | 18 +++++++++--------- .../pipeline_stable_diffusion.py | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index d24c4db5b420..e46176a32815 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -565,7 +565,7 @@ def __init__( self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.use_sliced_decode = False + self.use_slicing = False def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) @@ -586,24 +586,24 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) - def enable_sliced_decode(self): + def enable_slicing(self): r""" Enable sliced VAE decoding. - When this option is enabled, the VAE will split the input tensor in slices, to compute decoding in several - steps. This is useful to save some memory. + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. """ - self.use_sliced_decode = True + self.use_slicing = True - def disable_sliced_decode(self): + def disable_slicing(self): r""" - Disable sliced VAE decoding. If `enable_sliced_decode` was previously invoked, this method will go back to + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing decoding in one step. """ - self.use_sliced_decode = False + self.use_slicing = False def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_sliced_decode: + if self.use_slicing: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9acbf1bcbae7..f063426fa40c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -158,21 +158,21 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def enable_sliced_vae_decode(self): + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. - When this option is enabled, the VAE will split the input tensor in slices, to compute decoding in several + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ - self.vae.enable_sliced_decode() + self.vae.enable_slicing() - def disable_sliced_vae_decode(self): + def disable_vae_slicing(self): r""" - Disable sliced VAE decoding. If `enable_sliced_decode` was previously invoked, this method will go back to + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ - self.vae.disable_sliced_decode() + self.vae.disable_slicing() def enable_sequential_cpu_offload(self): r""" From aaca7410ff7d0971c762e63b48010383a82fe5ee Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Fri, 18 Nov 2022 12:26:05 +0800 Subject: [PATCH 04/10] fix whitespace --- src/diffusers/models/vae.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index e46176a32815..994ecae3426a 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -594,7 +594,7 @@ def enable_slicing(self): steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True - + def disable_slicing(self): r""" Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 98c9d4cce6af..df4178fc359c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -186,7 +186,7 @@ def enable_vae_slicing(self): steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() - + def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to From b0be8d03424b942a253edf6620445cd4c4a0c4f7 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Fri, 18 Nov 2022 12:32:23 +0800 Subject: [PATCH 05/10] fix quality check and repository consistency --- src/diffusers/models/vae.py | 4 ++-- .../alt_diffusion/pipeline_alt_diffusion.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 994ecae3426a..b1b46ac787c8 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -597,8 +597,8 @@ def enable_slicing(self): def disable_slicing(self): r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to - computing decoding in one step. + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. """ self.use_slicing = False diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index afb2c5288640..e63412c14338 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -179,6 +179,22 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, From 30683baf43aebed5767a400f54ae7c3f46a59678 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Wed, 23 Nov 2022 00:39:08 +0800 Subject: [PATCH 06/10] VAE slicing tests and documentation --- docs/source/optimization/fp16.mdx | 26 +++++++ .../stable_diffusion/test_stable_diffusion.py | 75 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index 4371daacc903..f12517dbc14c 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -117,6 +117,32 @@ image = pipe(prompt).images[0] There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM! + +## Sliced VAE decode for larger batches + +To decode large batches of images with limited VRAM, or to enable image batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time. +You also want to use attention slicing or xformers to run unet with large batch sizes and limited VRAM. + +To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example: + +```Python +import torch +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +pipe.enable_vae_slicing() +images = pipe([prompt] * 32).images +``` + +There may also be a small performance boost in VAE decode on multi-image batches. + ## Offloading to CPU with accelerate for memory savings For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass. diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 17a293e605fd..0f04b3a1a202 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -497,6 +497,42 @@ def test_stable_diffusion_attention_chunk(self): assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + def test_stable_diffusion_vae_slicing(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + image_count = 4 + + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe([prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # make sure sliced vae decode yields the same result + sd_pipe.enable_vae_slicing() + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe([prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # there is a small discrepancy at image borders vs. full batch decode + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3 + def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -777,6 +813,45 @@ def test_stable_diffusion_memory_chunking(self): assert mem_bytes > 3.75 * 10**9 assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + def test_stable_diffusion_vae_slicing(self): + torch.cuda.reset_peak_memory_stats() + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a photograph of an astronaut riding a horse" + + # enable vae slicing + pipe.enable_vae_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 4 GB is allocated + assert mem_bytes < 4e9 + + # disable vae slicing + pipe.disable_vae_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 4 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 4e9 + # There is a small discrepancy at the image borders vs. a fully batched version. + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3 + def test_stable_diffusion_text2img_pipeline_fp16(self): torch.cuda.reset_peak_memory_stats() model_id = "CompVis/stable-diffusion-v1-4" From 35324c41e4f4a7b314bbe45f78d8cdc638350932 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Wed, 23 Nov 2022 00:42:23 +0800 Subject: [PATCH 07/10] API doc hooks for VAE slicing --- docs/source/api/pipelines/stable_diffusion.mdx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 8b551f7a3b17..01d2ee0bf9b4 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -76,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca - __call__ - enable_attention_slicing - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing ## StableDiffusionImg2ImgPipeline [[autodoc]] StableDiffusionImg2ImgPipeline From 4ed1fac863a8fa4240803ce44a5c8fb7f7429a15 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Wed, 23 Nov 2022 00:45:37 +0800 Subject: [PATCH 08/10] reformat vae slicing tests --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 0f04b3a1a202..560955c77d84 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -523,12 +523,16 @@ def test_stable_diffusion_vae_slicing(self): image_count = 4 generator = torch.Generator(device=device).manual_seed(0) - output_1 = sd_pipe([prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + output_1 = sd_pipe( + [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np" + ) # make sure sliced vae decode yields the same result sd_pipe.enable_vae_slicing() generator = torch.Generator(device=device).manual_seed(0) - output_2 = sd_pipe([prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + output_2 = sd_pipe( + [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np" + ) # there is a small discrepancy at image borders vs. full batch decode assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3 From 51ef9396a3a8acce7d9b280b3fd9ed83fd5fcd98 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Wed, 23 Nov 2022 22:15:53 +0800 Subject: [PATCH 09/10] Skip VAE slicing for one-image batches --- src/diffusers/models/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b1b46ac787c8..e29f4e8afa2f 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -603,7 +603,7 @@ def disable_slicing(self): self.use_slicing = False def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_slicing: + if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: From 85430c7e93bd69ac8d0751a716a8e6f0e81e3435 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Wed, 23 Nov 2022 22:56:12 +0800 Subject: [PATCH 10/10] Documentation tweaks for VAE slicing --- docs/source/optimization/fp16.mdx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index f12517dbc14c..49fe3876bd4b 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -120,8 +120,9 @@ There's a small performance penalty of about 10% slower inference times, but thi ## Sliced VAE decode for larger batches -To decode large batches of images with limited VRAM, or to enable image batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time. -You also want to use attention slicing or xformers to run unet with large batch sizes and limited VRAM. +To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time. + +You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use. To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example: @@ -141,7 +142,8 @@ pipe.enable_vae_slicing() images = pipe([prompt] * 32).images ``` -There may also be a small performance boost in VAE decode on multi-image batches. +You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches. + ## Offloading to CPU with accelerate for memory savings