From 46eb217f2346e6e86efae0536e08d98ca9751b83 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Oct 2022 12:35:45 +0200 Subject: [PATCH 1/8] mps: alt. implementation for repeat_interleave --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 00e72de6551a..1faddad16ef2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -219,7 +219,12 @@ def __call__( text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + if self.device.type == "mps": + # Workaround for `repeat_interleave`. Assumes 3 dims. + d0, d1, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view(d0*num_images_per_prompt, d1, -1) + else: + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -257,7 +262,12 @@ def __call__( uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # duplicate unconditional embeddings for each generation per prompt - uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + if self.device.type == "mps": + # Workaround for `repeat_interleave`. Assumes 3 dims. + d0, d1, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view(d0*num_images_per_prompt, d1, -1) + else: + uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch From 54b02bb262ca5ca4f7dafcbb3a69be9d67026961 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Oct 2022 12:40:16 +0200 Subject: [PATCH 2/8] style --- .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1faddad16ef2..104e222ff01c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -222,7 +222,9 @@ def __call__( if self.device.type == "mps": # Workaround for `repeat_interleave`. Assumes 3 dims. d0, d1, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view(d0*num_images_per_prompt, d1, -1) + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( + d0 * num_images_per_prompt, d1, -1 + ) else: text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) @@ -265,7 +267,9 @@ def __call__( if self.device.type == "mps": # Workaround for `repeat_interleave`. Assumes 3 dims. d0, d1, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view(d0*num_images_per_prompt, d1, -1) + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view( + d0 * num_images_per_prompt, d1, -1 + ) else: uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) From e4f753cdb100905bc0330e7606869bbd279a0659 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Oct 2022 12:43:06 +0200 Subject: [PATCH 3/8] Bump mps version of PyTorch in the documentation. --- docs/source/optimization/mps.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx index ad347b84e840..ff9d614c870f 100644 --- a/docs/source/optimization/mps.mdx +++ b/docs/source/optimization/mps.mdx @@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License. - Mac computer with Apple silicon (M1/M2) hardware. - macOS 12.3 or later. - arm64 version of Python. -- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later. +- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later. ## Inference Pipeline From 5fe1598962172df48abae4fdfb44b80c67a7614f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 7 Oct 2022 19:00:37 +0200 Subject: [PATCH 4/8] Apply suggestions from code review Co-authored-by: Suraj Patil --- .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 104e222ff01c..2f03cc55bf30 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -221,9 +221,9 @@ def __call__( # duplicate text embeddings for each generation per prompt if self.device.type == "mps": # Workaround for `repeat_interleave`. Assumes 3 dims. - d0, d1, _ = text_embeddings.shape + batch_size, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( - d0 * num_images_per_prompt, d1, -1 + batch_size * num_images_per_prompt, seq_len, -1 ) else: text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) @@ -266,9 +266,9 @@ def __call__( # duplicate unconditional embeddings for each generation per prompt if self.device.type == "mps": # Workaround for `repeat_interleave`. Assumes 3 dims. - d0, d1, _ = uncond_embeddings.shape + batch_size, seq_len, _ = uncond_embeddings.shape uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view( - d0 * num_images_per_prompt, d1, -1 + batch_size * num_images_per_prompt, seq_len, -1 ) else: uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) From f98fb143f49118ca98cdb6fe759934d512e2ad78 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Oct 2022 21:10:41 +0200 Subject: [PATCH 5/8] Simplify: do not check for device. --- .../pipeline_stable_diffusion.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2f03cc55bf30..173ab35294ea 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -218,15 +218,11 @@ def __call__( text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - # duplicate text embeddings for each generation per prompt - if self.device.type == "mps": - # Workaround for `repeat_interleave`. Assumes 3 dims. - batch_size, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - else: - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + # duplicate text embeddings for each generation per prompt, using mps friendly method + batch_size, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( + batch_size * num_images_per_prompt, seq_len, -1 + ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -263,15 +259,11 @@ def __call__( ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - # duplicate unconditional embeddings for each generation per prompt - if self.device.type == "mps": - # Workaround for `repeat_interleave`. Assumes 3 dims. - batch_size, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - else: - uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + batch_size, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view( + batch_size * num_images_per_prompt, seq_len, -1 + ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch From 7bf3decee38f7793dc95d81baee1a16e136547ba Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Oct 2022 21:11:51 +0200 Subject: [PATCH 6/8] style --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 173ab35294ea..94904859de91 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -218,7 +218,7 @@ def __call__( text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - # duplicate text embeddings for each generation per prompt, using mps friendly method + # duplicate text embeddings for each generation per prompt, using mps friendly method batch_size, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( batch_size * num_images_per_prompt, seq_len, -1 From c49bf483b216b2db54335a7d37564b3235b33a84 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Oct 2022 19:53:31 +0200 Subject: [PATCH 7/8] Fix repeat dimensions: - The unconditional embeddings are always created from a single prompt. - I was shadowing the batch_size var. --- .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6a2907c8de36..4252f5ee1721 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -219,9 +219,9 @@ def __call__( text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method - batch_size, seq_len, _ = text_embeddings.shape + bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( - batch_size * num_images_per_prompt, seq_len, -1 + bs_embed * num_images_per_prompt, seq_len, -1 ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -260,8 +260,8 @@ def __call__( uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - batch_size, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1).view( + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1).view( batch_size * num_images_per_prompt, seq_len, -1 ) From 1aec2e5f993a3d60a87e9a39d8c125a7732f160c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Oct 2022 19:57:19 +0200 Subject: [PATCH 8/8] Split long lines as suggested by Suraj. --- .../stable_diffusion/pipeline_stable_diffusion.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4252f5ee1721..ca6c580ffc55 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -220,9 +220,8 @@ def __call__( # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -261,9 +260,8 @@ def __call__( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1).view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch