src/diffusers/__init__.py (1 addition, 0 deletions)

@@ -140,6 +140,7 @@
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetImg2ImgReferencePipeline,
         StableDiffusionControlNetImg2ImgPipeline,
         StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
src/diffusers/pipelines/__init__.py (1 addition, 0 deletions)

@@ -45,6 +45,7 @@
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
     from .controlnet import (
+        StableDiffusionControlNetImg2ImgReferencePipeline,
         StableDiffusionControlNetImg2ImgPipeline,
         StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
src/diffusers/pipelines/controlnet/__init__.py (1 addition, 0 deletions)

@@ -16,6 +16,7 @@
     from .pipeline_controlnet import StableDiffusionControlNetPipeline
     from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
     from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+    from .pipeline_controlnet_img2img_reference import StableDiffusionControlNetImg2ImgReferencePipeline


 if is_transformers_available() and is_flax_available():
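Taken together, the three `__init__.py` changes above register the new `StableDiffusionControlNetImg2ImgReferencePipeline` for top-level import. A minimal usage sketch, assuming the new pipeline keeps the same `from_pretrained` interface and component layout as `StableDiffusionControlNetImg2ImgPipeline` (the checkpoint IDs are illustrative, not taken from this PR):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgReferencePipeline

# Load a ControlNet conditioning model and attach it to the new pipeline.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")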
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py (43 additions, 16 deletions)

@@ -306,8 +306,8 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
         self,
         prompt,
         device,
@@ -317,12 +317,13 @@ def _encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.

         Args:
-            prompt (`str` or `List[str]`, *optional*):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
@@ -342,13 +343,22 @@ def _encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -387,13 +397,31 @@ def _encode_prompt(
             else:
                 attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -449,17 +477,16 @@ def _encode_prompt(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder, lora_scale)

-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
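The net effect of the `encode_prompt` refactor above: the method gains a `clip_skip` argument, scales and unscales text-encoder LoRA layers around the encode when the PEFT backend is active, and returns `(prompt_embeds, negative_prompt_embeds)` as separate tensors instead of a single pre-concatenated batch. A sketch of the new call pattern, with illustrative argument values (`pipe` is assumed to be an already loaded pipeline exposing this method):

import torch

do_classifier_free_guidance = True  # i.e. guidance_scale > 1.0

# encode_prompt now returns the two embedding tensors separately...
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    "a photo of an astronaut",
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=do_classifier_free_guidance,
    negative_prompt="low quality",
    clip_skip=1,  # use the pre-final CLIP hidden layer, per the docstring above
)

# ...so the caller performs the classifier-free-guidance concatenation that
# the old _encode_prompt did internally (the removed torch.cat lines).
if do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

Any caller that previously relied on `_encode_prompt` returning the concatenated batch must be updated to do this concatenation itself.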