From ac266fb0d75b4400d4042e0632fca672f6f2c9e3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 13:13:24 +0200 Subject: [PATCH 01/16] add multiple prompts to flux redux --- .../pipelines/flux/pipeline_flux_prior_redux.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index cf50e89ca5ae..90c3336096d1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -334,6 +334,10 @@ def encode_prompt( def __call__( self, image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, return_dict: bool = True, ): r""" @@ -378,10 +382,10 @@ def __call__( pooled_prompt_embeds, _, ) = self.encode_prompt( - prompt=[""] * batch_size, - prompt_2=None, - prompt_embeds=None, - pooled_prompt_embeds=None, + prompt=prompt * batch_size, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=1, max_sequence_length=512, From bf2e1492d4abbe0b4352253a2be58b141b1951ba Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 13:29:27 +0200 Subject: [PATCH 02/16] check inputs --- .../flux/pipeline_flux_prior_redux.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 90c3336096d1..8f14c36a3d04 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -142,6 +142,34 @@ def __init__( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) + def check_inputs( + self, + prompt, + prompt_2, + prompt_embeds=None, + pooled_prompt_embeds=None, + ): + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype image = self.feature_extractor.preprocess( @@ -367,6 +395,11 @@ def __call__( batch_size = len(image) else: batch_size = image.shape[0] + if prompt is not None and isinstance(prompt, str): + prompt = batch_size * [prompt] + + + device = self._execution_device # 3. Prepare image embeddings @@ -382,7 +415,7 @@ def __call__( pooled_prompt_embeds, _, ) = self.encode_prompt( - prompt=prompt * batch_size, + prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, From 27acef87bd1b615ae3de067280421d2078c1f580 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 15:46:33 +0200 Subject: [PATCH 03/16] check inputs --- .../pipelines/flux/pipeline_flux_prior_redux.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 8f14c36a3d04..91c99c1dd326 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -146,6 +146,7 @@ def check_inputs( self, prompt, prompt_2, + image, prompt_embeds=None, pooled_prompt_embeds=None, ): @@ -164,7 +165,8 @@ def check_inputs( raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - + if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): + raise ValueError(f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {batch_size} images") if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." @@ -388,6 +390,15 @@ def __call__( returning a tuple, the first element is a list with the generated images. """ + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + ) + # 2. Define call parameters if image is not None and isinstance(image, Image.Image): batch_size = 1 @@ -398,8 +409,6 @@ def __call__( if prompt is not None and isinstance(prompt, str): prompt = batch_size * [prompt] - - device = self._execution_device # 3. Prepare image embeddings From 6fbf2906f44a25f779ac89a93246154f2cae7d70 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 15:55:01 +0200 Subject: [PATCH 04/16] doc --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 91c99c1dd326..3530da8d5152 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -379,6 +379,14 @@ def __call__( numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple. From e6f26b9cfe40f2b14567df4e7d4284b99559b7ce Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 15:58:43 +0200 Subject: [PATCH 05/16] doc --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3530da8d5152..04e3eb1dcdfc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -368,6 +368,7 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + scales: Optional[Union[float, List[float]]] = None, return_dict: bool = True, ): r""" From 7198ec300e2d36ddab17a68a458eeada228d30ed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 29 Nov 2024 16:01:43 +0200 Subject: [PATCH 06/16] fix error --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 04e3eb1dcdfc..a5fbca24683b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -166,7 +166,7 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): - raise ValueError(f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {batch_size} images") + raise ValueError(f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images") if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." From 382e5569e3c877ba8d893c797d06a06598d65e50 Mon Sep 17 00:00:00 2001 From: Linoy Date: Fri, 29 Nov 2024 14:02:04 +0000 Subject: [PATCH 07/16] style --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 04e3eb1dcdfc..a6a93715c290 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -150,7 +150,6 @@ def check_inputs( prompt_embeds=None, pooled_prompt_embeds=None, ): - if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -166,7 +165,9 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): - raise ValueError(f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {batch_size} images") + raise ValueError( + f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {batch_size} images" + ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." From 7d13a416676d1aef1ffa826184ded36dc593600a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 1 Dec 2024 23:02:41 +0200 Subject: [PATCH 08/16] weighted sum --- .../flux/pipeline_flux_prior_redux.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 343b5a0a2001..01b4d0dae3ed 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -369,7 +369,8 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - scales: Optional[Union[float, List[float]]] = None, + prompt_embeds_scale: Optional[Union[float, List[float]]] = 1., + pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1., return_dict: bool = True, ): r""" @@ -418,6 +419,10 @@ def __call__( batch_size = image.shape[0] if prompt is not None and isinstance(prompt, str): prompt = batch_size * [prompt] + if isinstance(prompt_embeds_scale, float): + prompt_embeds_scale = batch_size * [prompt_embeds_scale] + if isinstance(pooled_prompt_embeds_scale, float): + pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] device = self._execution_device @@ -449,9 +454,21 @@ def __call__( # pooled_prompt_embeds is 768, clip text encoder hidden size pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) + print("1 prompt_embeds.shape", prompt_embeds.shape) + prompt_embeds_scale = torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] + pooled_prompt_embeds_scale = torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None] + + # Concatenate image and text embeddings prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) - + print("2 prompt_embeds.shape", prompt_embeds.shape) + prompt_embeds *= prompt_embeds_scale + pooled_prompt_embeds *= pooled_prompt_embeds_scale + print("3 prompt_embeds.shape", prompt_embeds.shape) + + prompt_embeds = torch.sum(prompt_embeds, dim=0) + pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0) + print("4 prompt_embeds.shape", prompt_embeds.shape) # Offload all models self.maybe_free_model_hooks() From b8dfdf7c22f31d83d1b343bcfd1c9ed6e2b2c0c6 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 2 Dec 2024 13:00:12 +0200 Subject: [PATCH 09/16] fix --- .../flux/pipeline_flux_prior_redux.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 01b4d0dae3ed..696ac3c38f58 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -454,21 +454,16 @@ def __call__( # pooled_prompt_embeds is 768, clip text encoder hidden size pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) - print("1 prompt_embeds.shape", prompt_embeds.shape) - prompt_embeds_scale = torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] - pooled_prompt_embeds_scale = torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None] + # scale & oncatenate image and text embeddings + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] + pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None] + + # weighted sum + prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) + pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) - # Concatenate image and text embeddings - prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) - print("2 prompt_embeds.shape", prompt_embeds.shape) - prompt_embeds *= prompt_embeds_scale - pooled_prompt_embeds *= pooled_prompt_embeds_scale - print("3 prompt_embeds.shape", prompt_embeds.shape) - - prompt_embeds = torch.sum(prompt_embeds, dim=0) - pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0) - print("4 prompt_embeds.shape", prompt_embeds.shape) # Offload all models self.maybe_free_model_hooks() From 012a0ec024988e46f26c1ce8c6b5a6a27857ffab Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 2 Dec 2024 11:02:37 +0000 Subject: [PATCH 10/16] style --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 696ac3c38f58..c6f3cc155c3f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -369,8 +369,8 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_embeds_scale: Optional[Union[float, List[float]]] = 1., - pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1., + prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, + pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, return_dict: bool = True, ): r""" @@ -458,7 +458,9 @@ def __call__( prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] - pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None] + pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ + :, None + ] # weighted sum prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) From 5af6811007e0ec88d96b83cceeab11c98b120fe0 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 2 Dec 2024 13:10:05 +0200 Subject: [PATCH 11/16] check len of scales --- .../pipelines/flux/pipeline_flux_prior_redux.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 696ac3c38f58..6066e096799a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -144,11 +144,14 @@ def __init__( def check_inputs( self, + image, prompt, prompt_2, - image, prompt_embeds=None, pooled_prompt_embeds=None, + prompt_embeds_scale=1., + pooled_prompt_embeds_scale=1., + ): if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -172,6 +175,10 @@ def check_inputs( raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) + if isinstance(prompt_embeds_scale, list) and (isinstance(image, list) and len(prompt_embeds_scale) != len(image)): + raise ValueError( + f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" + ) def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype From bf68f2e1bbbf04afdf06b8c91290d216460c6024 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 2 Dec 2024 11:10:59 +0000 Subject: [PATCH 12/16] style --- .../pipelines/flux/pipeline_flux_prior_redux.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index ec95a467f4a5..79bcf9c1e6cd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -149,9 +149,8 @@ def check_inputs( prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, - prompt_embeds_scale=1., - pooled_prompt_embeds_scale=1., - + prompt_embeds_scale=1.0, + pooled_prompt_embeds_scale=1.0, ): if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -175,7 +174,9 @@ def check_inputs( raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - if isinstance(prompt_embeds_scale, list) and (isinstance(image, list) and len(prompt_embeds_scale) != len(image)): + if isinstance(prompt_embeds_scale, list) and ( + isinstance(image, list) and len(prompt_embeds_scale) != len(image) + ): raise ValueError( f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" ) From d2b4881d566285f5e4a15350daa7b7c8ab01158d Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 2 Dec 2024 21:54:25 +0200 Subject: [PATCH 13/16] Update src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py Co-authored-by: hlky --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 79bcf9c1e6cd..3e178ca05b65 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -462,7 +462,7 @@ def __call__( # pooled_prompt_embeds is 768, clip text encoder hidden size pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) - # scale & oncatenate image and text embeddings + # scale & concatenate image and text embeddings prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] From a9e893e98aaa55ad9076c524a7d5e797d8e4bfed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 2 Dec 2024 21:55:41 +0200 Subject: [PATCH 14/16] fix check_inputs call --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3e178ca05b65..f56e9a9f85e8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -411,11 +411,13 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( + image, prompt, prompt_2, - image, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + prompt_embeds_scale=prompt_embeds_scale, + pooled_prompt_embeds_scale=pooled_prompt_embeds_scale, ) # 2. Define call parameters From 7c93dd0208579fd3225fa0a77ba7a24f072cdd44 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 3 Dec 2024 13:29:11 +0200 Subject: [PATCH 15/16] add warning in doc on providing prompts --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index f56e9a9f85e8..80eb1d5d259a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -392,6 +392,8 @@ def __call__( list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. + **experimental feature**: to use this feature, make sure to explicitly load text encoders to + the pipeline. Prompts will be ignored if text encoders are not loaded. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. prompt_embeds (`torch.FloatTensor`, *optional*): @@ -459,6 +461,11 @@ def __call__( lora_scale=None, ) else: + if prompt is not None: + logger.warning( + "prompt input is ignored when text encoders are not loaded to the pipeline. " + "Make sure to explicitly load the text encoders to enable prompt input. " + ) # max_sequence_length is 512, t5 encoder hidden size is 4096 prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype) # pooled_prompt_embeds is 768, clip text encoder hidden size From 1ab7060858ceec073216531d1db4dcc1c467bfb5 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 4 Dec 2024 16:35:45 +0000 Subject: [PATCH 16/16] make style --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 80eb1d5d259a..f53958df2ed0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -391,9 +391,9 @@ def __call__( or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. - **experimental feature**: to use this feature, make sure to explicitly load text encoders to - the pipeline. Prompts will be ignored if text encoders are not loaded. + The prompt or prompts to guide the image generation. **experimental feature**: to use this feature, + make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders + are not loaded. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. prompt_embeds (`torch.FloatTensor`, *optional*):