diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 25ca322351d3..bba8d4084636 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -109,6 +109,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
+ from .loaders import TextualInversionLoaderMixin
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index d6bb6fde6ac1..265ea92625f5 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -13,18 +13,28 @@
# limitations under the License.
import os
from collections import defaultdict
-from typing import Callable, Dict, Union
+from typing import Callable, Dict, List, Optional, Union
import torch
from .models.attention_processor import LoRAAttnProcessor
-from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+ DIFFUSERS_CACHE,
+ HF_HUB_OFFLINE,
+ _get_model_file,
+ deprecate,
+ is_safetensors_available,
+ is_transformers_available,
+ logging,
+)
if is_safetensors_available():
import safetensors
+if is_transformers_available():
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
logger = logging.get_logger(__name__)
@@ -32,6 +42,9 @@
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -123,13 +136,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
- <Tip>
-
- Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
- this method in a firewalled environment.
-
- </Tip>
"""
@@ -292,5 +298,272 @@ def save_function(weights, filename):
# Save the model
save_function(state_dict, os.path.join(save_directory, weight_name))
-
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextualInversionLoaderMixin:
+ r"""
+ Mixin class for loading textual inversion tokens and embeddings into the tokenizer and text encoder.
+ """
+
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
+ r"""
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
+
+ Parameters:
+ prompt (`str` or list of `str`):
+ The prompt or prompts to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str` or list of `str`: The converted prompt
+ """
+ if not isinstance(prompt, List):
+ prompts = [prompt]
+ else:
+ prompts = prompt
+
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+ if not isinstance(prompt, List):
+ return prompts[0]
+
+ return prompts
+
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
+ r"""
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
+
+ Parameters:
+ prompt (`str`):
+ The prompt to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str`: The converted prompt
+ """
+ tokens = tokenizer.tokenize(prompt)
+ for token in tokens:
+ if token in tokenizer.added_tokens_encoder:
+ replacement = token
+ i = 1
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+ replacement += f"{token}_{i}"
+ i += 1
+
+ prompt = prompt.replace(token, replacement)
+
+ return prompt
+
+ def load_textual_inversion(
+ self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+ ):
+ r"""
+ Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
+ `Automatic1111` formats are supported.
+
+ <Tip warning={true}>
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like
+ `"sd-concepts-library/low-poly-hd-logos-icons"`.
+ - A path to a *directory* containing textual inversion weights, e.g.
+ `./my_text_inversion_directory/`.
+ weight_name (`str`, *optional*):
+ Name of a custom weight file. This should be used in two cases:
+
+ - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight
+ name, such as `text_inv.bin`.
+ - The saved textual inversion file is in the "Automatic1111" format.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ <Tip>
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+ </Tip>
+ """
+ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "text_inversion",
+ "framework": "pytorch",
+ }
+
+ # 1. Load textual inversion file
+ model_file = None
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except Exception as e:
+ if not allow_pickle:
+ raise e
+
+ model_file = None
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ # 2. Load token and embedding correctly from file
+ if isinstance(state_dict, torch.Tensor):
+ if token is None:
+ raise ValueError(
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+ )
+ embedding = state_dict
+ # a bare tensor carries no token name, so fall back to the user-supplied one
+ loaded_token = token
+ elif len(state_dict) == 1:
+ # diffusers
+ loaded_token, embedding = next(iter(state_dict.items()))
+ elif "string_to_param" in state_dict:
+ # A1111
+ loaded_token = state_dict["name"]
+ embedding = state_dict["string_to_param"]["*"]
+
+ if token is not None and loaded_token != token:
+ logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+ else:
+ token = loaded_token
+
+ embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+
+ # 3. Make sure we don't mess up the tokenizer or text encoder
+ vocab = self.tokenizer.get_vocab()
+ if token in vocab:
+ raise ValueError(
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+ )
+ elif f"{token}_1" in vocab:
+ multi_vector_tokens = [token]
+ i = 1
+ while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+ multi_vector_tokens.append(f"{token}_{i}")
+ i += 1
+
+ raise ValueError(
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+ )
+
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+
+ if is_multi_vector:
+ tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+ embeddings = [e for e in embedding] # noqa: C416
+ else:
+ tokens = [token]
+ embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+ # add tokens and get ids
+ self.tokenizer.add_tokens(tokens)
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+
+ # resize token embeddings and set new embeddings
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+ for token_id, embedding in zip(token_ids, embeddings):
+ self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+
+ logger.info("Loaded textual inversion embedding for {token}.")
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 5a5d233fbb4e..6a849f6f0e45 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -16,27 +16,22 @@
import inspect
import os
-import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
-from packaging import version
-from requests import HTTPError
from torch import Tensor, device
from .. import __version__
from ..utils import (
CONFIG_NAME,
- DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
+ _add_variant,
+ _get_model_file,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
@@ -144,15 +139,6 @@ def load(module: torch.nn.Module, prefix=""):
return error_msgs
-def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
- if variant is not None:
- splits = weights_name.split(".")
- splits = splits[:-1] + [variant] + splits[-1:]
- weights_name = ".".join(splits)
-
- return weights_name
-
-
class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
@@ -789,121 +775,3 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
-
-
-def _get_model_file(
- pretrained_model_name_or_path,
- *,
- weights_name,
- subfolder,
- cache_dir,
- force_download,
- proxies,
- resume_download,
- local_files_only,
- use_auth_token,
- user_agent,
- revision,
- commit_hash=None,
-):
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- if os.path.isfile(pretrained_model_name_or_path):
- return pretrained_model_name_or_path
- elif os.path.isdir(pretrained_model_name_or_path):
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
- # Load from a PyTorch checkpoint
- model_file = os.path.join(pretrained_model_name_or_path, weights_name)
- return model_file
- elif subfolder is not None and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
- ):
- model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
- return model_file
- else:
- raise EnvironmentError(
- f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
- )
- else:
- # 1. First check if deprecated way of loading from branches is used
- if (
- revision in DEPRECATED_REVISION_ARGS
- and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
- and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
- ):
- try:
- model_file = hf_hub_download(
- pretrained_model_name_or_path,
- filename=_add_variant(weights_name, revision),
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- subfolder=subfolder,
- revision=revision or commit_hash,
- )
- warnings.warn(
- f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
- FutureWarning,
- )
- return model_file
- except: # noqa: E722
- warnings.warn(
- f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
- FutureWarning,
- )
- try:
- # 2. Load model file as usual
- model_file = hf_hub_download(
- pretrained_model_name_or_path,
- filename=weights_name,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- subfolder=subfolder,
- revision=revision or commit_hash,
- )
- return model_file
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
- "this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
- )
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
- f" directory containing a file named {weights_name} or"
- " \nCheckout your internet connection or see how to run the library in"
- " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a file named {weights_name}"
- )
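`_add_variant` and `_get_model_file` are moved, not modified; they reappear verbatim in `utils/hub_utils.py` further down and are re-exported from `diffusers.utils`. As a quick reminder of what `_add_variant` computes (a minimal sketch; the helper is private, so the import path is an internal detail):

```python
from diffusers.utils import _add_variant

# the variant is spliced in just before the file extension
assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
# with no variant, the name passes through unchanged
assert _add_variant("diffusion_pytorch_model.bin", None) == "diffusion_pytorch_model.bin"
```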
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 68ad20c1598a..c5bb8f9ac7b1 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -22,6 +22,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -49,7 +50,7 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionPipeline(DiffusionPipeline):
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
@@ -312,6 +313,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -372,6 +377,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
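With the mixin wired into the pipelines, loading an embedding becomes a one-liner. A minimal usage sketch (the concept repo and its `<cat-toy>` token are examples; any repo holding a `learned_embeds.bin` works):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# downloads learned_embeds.bin, registers the token with the tokenizer and
# appends the learned vector(s) to the text encoder's embedding matrix
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

image = pipe("a photo of <cat-toy> riding a bicycle").images[0]
image.save("cat_toy.png")
```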
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 3521867f2b9f..9af55d1d018a 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -25,6 +25,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -88,7 +89,7 @@ def preprocess(image):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Alt Diffusion.
@@ -322,6 +323,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -382,6 +387,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
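The `maybe_convert_prompt` hook repeated in every pipeline's `_encode_prompt` is a no-op unless the prompt references a multi-vector embedding. A standalone sketch of the expansion, with hypothetical token names (instantiating the bare mixin is for demonstration only):

```python
from transformers import CLIPTokenizer
from diffusers.loaders import TextualInversionLoaderMixin

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# simulate a 3-vector embedding that load_textual_inversion would register
tokenizer.add_tokens(["<cat-toy>", "<cat-toy>_1", "<cat-toy>_2"])

out = TextualInversionLoaderMixin().maybe_convert_prompt(
    "a photo of <cat-toy> on the beach", tokenizer
)
print(out)  # a photo of <cat-toy> <cat-toy>_1 <cat-toy>_2 on the beach
```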
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 08dad43784f8..dd8e4f16dfc0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -24,6 +24,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
@@ -118,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
return noise
-class CycleDiffusionPipeline(DiffusionPipeline):
+class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -338,6 +339,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -398,6 +403,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index b428b4341849..73b9178e3ab1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -20,6 +20,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -52,7 +53,7 @@
"""
-class StableDiffusionPipeline(DiffusionPipeline):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -315,6 +316,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -375,6 +380,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index ae92ba5526a8..46adb6967140 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -21,6 +21,7 @@
from torch.nn import functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers
@@ -159,7 +160,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
return hidden_states
-class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
+class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite.
@@ -335,6 +336,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -395,6 +400,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index d7f84d2e697b..93cbc03b12ed 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -23,6 +23,7 @@
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import ModelMixin
@@ -146,7 +147,7 @@ def forward(
return down_block_res_samples, mid_block_res_sample
-class StableDiffusionControlNetPipeline(DiffusionPipeline):
+class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -354,6 +355,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -414,6 +419,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 876b1b8305f2..54f00ebc23f2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -23,6 +23,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -54,7 +55,7 @@ def preprocess(image):
return image
-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -200,6 +201,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -260,6 +265,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 14512e180992..e47fae663de3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -23,6 +23,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -91,7 +92,7 @@ def preprocess(image):
return image
-class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -329,6 +330,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -389,6 +394,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index a934f639a508..8e0ea5a8d079 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -22,6 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -137,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask):
return mask, masked_image
-class StableDiffusionInpaintPipeline(DiffusionPipeline):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
@@ -381,6 +382,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -441,6 +446,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index feb13d100089..b7a0c942bbe2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -22,6 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -81,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8):
return mask
-class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
+class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
@@ -317,6 +318,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -377,6 +382,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 40cde74a0596..f7999a08dc9b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -20,6 +20,7 @@
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -60,7 +61,7 @@ def preprocess(image):
return image
-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
@@ -511,6 +512,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -571,6 +576,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 6a895a6d0f29..3d10c7d4e8e8 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -18,6 +18,7 @@
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from ...loaders import TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -41,7 +42,7 @@ def apply_model(self, *args, **kwargs):
return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
-class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
+class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -238,6 +239,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -298,6 +303,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
index 0e850b43bd7c..d841bd8a2d26 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -18,6 +18,7 @@
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin
@@ -52,7 +53,7 @@
"""
-class StableDiffusionModelEditingPipeline(DiffusionPipeline):
+class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".
@@ -266,6 +267,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -326,6 +331,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index fdae1ed3679b..c47423bdee5b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -17,6 +17,7 @@
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -47,7 +48,7 @@
"""
-class StableDiffusionPanoramaPipeline(DiffusionPipeline):
+class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
Generation".
@@ -230,6 +231,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -290,6 +295,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 89cf823a1f7e..6af923cb7743 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -28,6 +28,7 @@
CLIPTokenizer,
)
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
@@ -50,7 +51,7 @@
-class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
+class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
@@ -470,6 +471,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -530,6 +535,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index d77e3550fc75..2b08cf662bb4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -19,6 +19,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -87,7 +88,7 @@ def __call__(
# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
-class StableDiffusionSAGPipeline(DiffusionPipeline):
+class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -247,6 +248,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -307,6 +312,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index e21b41ccac6d..606202bd3911 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -20,6 +20,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
@@ -50,7 +51,7 @@ def preprocess(image):
return image
-class StableDiffusionUpscalePipeline(DiffusionPipeline):
+class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
@@ -194,6 +195,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -254,6 +259,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 9c3d39564f6e..ce41572e683c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -19,6 +19,7 @@
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -47,7 +48,7 @@
"""
-class StableUnCLIPPipeline(DiffusionPipeline):
+class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
"""
Pipeline for text-to-image generation using stable unCLIP.
@@ -367,6 +368,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -427,6 +432,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index c8fb3f8021b9..b9bf00bc7835 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -21,6 +21,7 @@
from diffusers.utils.import_utils import is_accelerate_available
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -60,7 +61,7 @@
"""
-class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
+class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
"""
Pipeline for text-guided image to image generation using stable unCLIP.
@@ -267,6 +268,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -327,6 +332,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 9129ae0118b8..1cbe78f0c964 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -19,6 +19,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer
+from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -72,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -
return images
-class TextToVideoSDPipeline(DiffusionPipeline):
+class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -256,6 +257,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -316,6 +321,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 615804c91a19..3a1103ac1adf 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -37,6 +37,8 @@
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import (
HF_HUB_OFFLINE,
+ _add_variant,
+ _get_model_file,
extract_commit_hash,
http_user_agent,
)
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index ab85566049d8..cf85ff157f57 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
+class TextualInversionLoaderMixin(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
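The dummy mirrors the real class name so that `from diffusers import TextualInversionLoaderMixin` succeeds even without `transformers` installed; the failure is deferred to first use. Roughly (a sketch assuming the backend is missing):

```python
# with transformers absent, diffusers exposes the dummy class instead
from diffusers import TextualInversionLoaderMixin

try:
    TextualInversionLoaderMixin()  # fails at use time, not import time
except ImportError as err:
    print(err)  # message names the missing `transformers` backend
```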
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 916b18d35e7e..511763ec6687 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -18,16 +18,30 @@
import re
import sys
import traceback
+import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
-from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
+from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ is_jinja_available,
+)
+from packaging import version
+from requests import HTTPError
from .. import __version__
-from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
+from .constants import (
+ DEPRECATED_REVISION_ARGS,
+ DIFFUSERS_CACHE,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+)
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_flax_version,
@@ -215,3 +229,130 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
"the directory exists and can be written to."
)
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+ if variant is not None:
+ splits = weights_name.split(".")
+ splits = splits[:-1] + [variant] + splits[-1:]
+ weights_name = ".".join(splits)
+
+ return weights_name
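`_add_variant` splices the variant between the filename stem and its extension, which is how variant revisions such as `fp16` map onto per-variant filenames. For example:

    # The variant lands before the final extension; None is a no-op.
    assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
    assert _add_variant("learned_embeds.safetensors", None) == "learned_embeds.safetensors"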
+
+
+def _get_model_file(
+ pretrained_model_name_or_path,
+ *,
+ weights_name,
+ subfolder,
+ cache_dir,
+ force_download,
+ proxies,
+ resume_download,
+ local_files_only,
+ use_auth_token,
+ user_agent,
+ revision,
+ commit_hash=None,
+):
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
+ return pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+ return model_file
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ return model_file
+ else:
+ raise EnvironmentError(
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ # 1. First check if deprecated way of loading from branches is used
+ if (
+ revision in DEPRECATED_REVISION_ARGS
+ and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
+ and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
+ ):
+ try:
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=_add_variant(weights_name, revision),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision or commit_hash,
+ )
+ warnings.warn(
+ f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
+ FutureWarning,
+ )
+ return model_file
+ except Exception:  # deliberately broad: any failure falls through to the normal load path below
+ warnings.warn(
+ f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
+ FutureWarning,
+ )
+ try:
+ # 2. Load model file as usual
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=weights_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision or commit_hash,
+ )
+ return model_file
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {weights_name} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {weights_name}"
+ )
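Taken together, `_get_model_file` resolves in three steps: an explicit file path wins, then a local directory (optionally under `subfolder`), and finally a Hub download, with the deprecation shim above retrying deprecated `revision` values as filename variants first. A hedged call sketch (the repo id and user agent are illustrative; in practice the keyword arguments are forwarded from a loader's kwargs):

    # Illustrative only; most keyword arguments mirror hf_hub_download's.
    model_file = _get_model_file(
        "sd-concepts-library/low-poly-hd-logos-icons",
        weights_name="learned_embeds.bin",
        subfolder=None,
        cache_dir=None,
        force_download=False,
        proxies=None,
        resume_download=False,
        local_files_only=False,
        use_auth_token=None,
        user_agent={"file_type": "text_inversion", "framework": "pytorch"},
        revision=None,
    )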
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index f4e8113a298f..c3ad88b34acb 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -21,6 +21,7 @@
import numpy as np
import torch
+from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -886,6 +887,32 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
assert mem_bytes_slicing < mem_bytes_offloaded
assert mem_bytes_slicing < 3 * 10**9
+ def test_stable_diffusion_textual_inversion(self):
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+ pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
+
+ a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
+ a111_file_neg = hf_hub_download(
+ "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
+ )
+ pipe.load_textual_inversion(a111_file)
+ pipe.load_textual_inversion(a111_file_neg)
+ pipe.to("cuda")
+
+ generator = torch.Generator(device="cpu").manual_seed(1)
+
+ prompt = "An logo of a turtle in strong Style-Winter with "
+ neg_prompt = "Style-Winter-neg"
+
+ image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
+
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
+ )
+
+ max_diff = np.abs(expected_image - image).max()
+ assert max_diff < 5e-3
+
@nightly
@require_torch_gpu
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 2616223c5447..0525eaca50da 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -362,6 +362,97 @@ def test_download_broken_variant(self):
diffusers.utils.import_utils._safetensors_available = True
+ def test_text_inversion_download(self):
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+ )
+ pipe = pipe.to(torch_device)
+
+ num_tokens = len(pipe.tokenizer)
+
+ # single token load local
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ten = {"<*>": torch.ones((32,))}
+ torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+ pipe.load_textual_inversion(tmpdirname)
+
+ token = pipe.tokenizer.convert_tokens_to_ids("<*>")
+ assert token == num_tokens, "Added token must be at spot `num_tokens`"
+ assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
+ assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>"
+
+ prompt = "hey <*>"
+ out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+ assert out.shape == (1, 128, 128, 3)
+
+ # single token load local with weight name
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ten = {"<**>": 2 * torch.ones((1, 32))}
+ torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+ pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin")
+
+ token = pipe.tokenizer.convert_tokens_to_ids("<**>")
+ assert token == num_tokens + 1, "Added token must be at spot `num_tokens + 1`"
+ assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+ assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>"
+
+ prompt = "hey <**>"
+ out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+ assert out.shape == (1, 128, 128, 3)
+
+ # multi token load
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}
+ torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
+
+ pipe.load_textual_inversion(tmpdirname)
+
+ token = pipe.tokenizer.convert_tokens_to_ids("<***>")
+ token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1")
+ token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2")
+
+ assert token == num_tokens + 2, "Added token must be at spot `num_tokens + 2`"
+ assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens + 3`"
+ assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens + 4`"
+ assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+ assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+ assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+ assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"
+
+ prompt = "hey <***>"
+ out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+ assert out.shape == (1, 128, 128, 3)
+
+ # multi token load a1111
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ten = {
+ "string_to_param": {
+ "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
+ },
+ "name": "<****>",
+ }
+ torch.save(ten, os.path.join(tmpdirname, "a1111.bin"))
+
+ pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin")
+
+ token = pipe.tokenizer.convert_tokens_to_ids("<****>")
+ token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1")
+ token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2")
+
+ assert token == num_tokens + 5, "Added token must be at spot `num_tokens + 5`"
+ assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens + 6`"
+ assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens + 7`"
+ assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+ assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+ assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+ assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"
+
+ prompt = "hey <****>"
+ out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+ assert out.shape == (1, 128, 128, 3)
+
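For reference, the `test_text_inversion_download` test above also documents the two checkpoint layouts `load_textual_inversion` accepts. A minimal sketch of both, with keys taken from the tests (the token name and sizes are illustrative):

    import torch

    # Diffusers format: placeholder token -> embedding; a (num_vectors, dim)
    # tensor triggers the multi-vector expansion asserted above.
    diffusers_format = {"<concept>": torch.randn(3, 32)}

    # A1111 format: the embedding sits under "string_to_param" and the
    # placeholder token under "name"; the loader normalizes it to the form above.
    a1111_format = {
        "string_to_param": {"*": torch.randn(3, 32)},
        "name": "<concept>",
    }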
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):