diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 4cf4cf36e4e4..5abee475d9c9 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -186,6 +186,8 @@
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
+ - local: api/pipelines/auto_pipeline
+ title: AutoPipeline
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
diff --git a/docs/source/en/api/pipelines/auto_pipeline.mdx b/docs/source/en/api/pipelines/auto_pipeline.mdx
new file mode 100644
index 000000000000..4ae2b86ac269
--- /dev/null
+++ b/docs/source/en/api/pipelines/auto_pipeline.mdx
@@ -0,0 +1,68 @@
+
+
+# AutoPipeline
+
+In many cases, one checkpoint can be used for multiple tasks. For example, you may be able to use the same checkpoint for Text-to-Image, Image-to-Image, and Inpainting. However, you need to know the name of the pipeline class linked to your checkpoint for each task.
+
+AutoPipeline is designed to make it easy for you to use multiple pipelines in your workflow. We currently provide three AutoPipeline classes, one per task: [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. Choose the AutoPipeline class based on the task you want to perform, and use it to automatically retrieve the relevant pipeline given the name or path of the pretrained weights.
+
+For example, to perform Image-to-Image with the Stable Diffusion v1.5 checkpoint, you can do:
+
+```python
+from diffusers import AutoPipelineForImage2Image
+
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+```
+
+It will also help you switch between tasks seamlessly using the same checkpoint without allocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do:
+
+```python
+from diffusers import AutoPipelineForInpainting
+
+pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i)
+```
+All the components are transferred to the inpainting pipeline at zero cost.
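+
+You can also override components or config options while switching tasks. As a minimal sketch (assuming a Stable Diffusion checkpoint, whose pipelines expose the `requires_safety_checker` config attribute), any extra keyword argument accepted by the target pipeline is applied on top of the transferred components:
+
+```python
+pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i, requires_safety_checker=False)
+```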
+
+
+Currently AutoPipeline supports the Text-to-Image, Image-to-Image, and Inpainting tasks for the following diffusion models (see the example after the list):
+- [Stable Diffusion](./stable_diffusion)
+- [Stable Diffusion ControlNet](./controlnet)
+- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl)
+- [IF](./if)
+- [Kandinsky](./kandinsky)
+- [Kandinsky 2.2](./kandinsky)
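+
+The same API works for every model above. For example, loading a Kandinsky Text-to-Image pipeline looks identical to the Stable Diffusion case. A minimal sketch, using the `kandinsky-community/kandinsky-2-1` checkpoint (the same repo used in this PR's tests):
+
+```python
+from diffusers import AutoPipelineForText2Image
+
+pipe_t2i = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-1")
+print(pipe_t2i.__class__.__name__)  # KandinskyPipeline, per the task mapping
+```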
+
+
+## AutoPipelineForText2Image
+
+[[autodoc]] AutoPipelineForText2Image
+ - all
+ - from_pretrained
+ - from_pipe
+
+
+## AutoPipelineForImage2Image
+
+[[autodoc]] AutoPipelineForImage2Image
+ - all
+ - from_pretrained
+ - from_pipe
+
+## AutoPipelineForInpainting
+
+[[autodoc]] AutoPipelineForInpainting
+ - all
+ - from_pretrained
+ - from_pipe
+
+
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 445e63b4d406..81b863c7e65e 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -61,6 +61,9 @@
)
from .pipelines import (
AudioPipelineOutput,
+ AutoPipelineForImage2Image,
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
ConsistencyModelPipeline,
DanceDiffusionPipeline,
DDIMPipeline,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 802ae4f5bc94..cd667292f907 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -17,6 +17,7 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
+ from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
new file mode 100644
index 000000000000..c827231ada7d
--- /dev/null
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -0,0 +1,834 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from collections import OrderedDict
+
+from ..configuration_utils import ConfigMixin
+from .controlnet import (
+ StableDiffusionControlNetImg2ImgPipeline,
+ StableDiffusionControlNetInpaintPipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionXLControlNetPipeline,
+)
+from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
+from .kandinsky import KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline
+from .kandinsky2_2 import KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline
+from .stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+)
+from .stable_diffusion_xl import (
+ StableDiffusionXLImg2ImgPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLPipeline,
+)
+
+
+AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", StableDiffusionPipeline),
+ ("stable-diffusion-xl", StableDiffusionXLPipeline),
+ ("if", IFPipeline),
+ ("kandinsky", KandinskyPipeline),
+ ("kandinsky22", KandinskyV22Pipeline),
+ ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
+ ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+ ]
+)
+
+AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", StableDiffusionImg2ImgPipeline),
+ ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
+ ("if", IFImg2ImgPipeline),
+ ("kandinsky", KandinskyImg2ImgPipeline),
+ ("kandinsky22", KandinskyV22Img2ImgPipeline),
+ ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
+ ]
+)
+
+AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", StableDiffusionInpaintPipeline),
+ ("stable-diffusion-xl", StableDiffusionXLInpaintPipeline),
+ ("if", IFInpaintingPipeline),
+ ("kandinsky", KandinskyInpaintPipeline),
+ ("kandinsky22", KandinskyV22InpaintPipeline),
+ ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
+ ]
+)
+
+SUPPORTED_TASKS_MAPPINGS = [
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
+ AUTO_INPAINT_PIPELINES_MAPPING,
+]
+
+
+def _get_task_class(mapping, pipeline_class_name):
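+    # map `pipeline_class_name` back to its model family key (e.g. "stable-diffusion")
+    # by scanning all task mappings, then look up the pipeline registered for that
+    # family in the requested task `mapping`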
+ def get_model(pipeline_class_name):
+ for task_mapping in SUPPORTED_TASKS_MAPPINGS:
+ for model_name, pipeline in task_mapping.items():
+ if pipeline.__name__ == pipeline_class_name:
+ return model_name
+
+ model_name = get_model(pipeline_class_name)
+
+ if model_name is not None:
+ task_class = mapping.get(model_name, None)
+ if task_class is not None:
+ return task_class
+ raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
+
+
+def _get_signature_keys(obj):
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+    optional_parameters = {k for k, v in parameters.items() if v.default != inspect._empty}
+ expected_modules = set(required_parameters.keys()) - {"self"}
+ return expected_modules, optional_parameters
+
+
+class AutoPipelineForText2Image(ConfigMixin):
+ r"""
+
+ AutoPipeline for text-to-image generation.
+
+    [`AutoPipelineForText2Image`] is a generic pipeline class that will be instantiated as one of the text-to-image
+    pipeline classes in diffusers.
+
+    The pipeline type (for example [`StableDiffusionPipeline`]) is automatically selected when created with the
+    `AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path)` or
+    `AutoPipelineForText2Image.from_pipe(pipeline)` class methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+ r"""
+        Instantiates a text-to-image PyTorch diffusion pipeline from pretrained pipeline weights.
+
+        The `from_pretrained()` method takes care of returning the correct pipeline class instance by:
+        1. Detecting the pipeline class of the `pretrained_model_or_path` based on the `_class_name` property of its
+           config object
+        2. Finding the text-to-image pipeline linked to the pipeline class using pattern matching on the pipeline
+           class name.
+
+ If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetPipeline`] object.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+        >>> from diffusers import AutoPipelineForText2Image
+
+        >>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> print(pipeline.__class__)
+ ```
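+
+        If the checkpoint should be combined with a ControlNet, pass the `controlnet` argument and the ControlNet
+        variant of the pipeline is selected instead. A minimal sketch (the `lllyasviel/sd-controlnet-canny`
+        checkpoint is used here purely for illustration):
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image, ControlNetModel
+
+        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+        >>> pipeline = AutoPipelineForText2Image.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet
+        ... )
+        >>> print(pipeline.__class__)  # StableDiffusionControlNetPipeline
+        ```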
+ """
+ config = cls.load_config(pretrained_model_or_path)
+ orig_class_name = config["_class_name"]
+
+ if "controlnet" in kwargs:
+ orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+
+ text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
+
+ return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ r"""
+        Instantiates a text-to-image PyTorch diffusion pipeline from another instantiated diffusion pipeline class.
+
+        The `from_pipe()` method takes care of returning the correct pipeline class instance by finding the
+        text-to-image pipeline linked to the pipeline class using pattern matching on the pipeline class name.
+
+        All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
+        additional memory.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pipeline (`DiffusionPipeline`):
+ an instantiated `DiffusionPipeline` object
+
+        Examples:
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+
+        >>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+        ... )
+
+        >>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i)
+ ```
+ """
+
+ original_config = dict(pipeline.config)
+ original_cls_name = pipeline.__class__.__name__
+
+ # derive the pipeline class to instantiate
+ text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name)
+
+ # define expected module and optional kwargs given the pipeline signature
+ expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
+
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ original_class_obj = {
+ k: pipeline.components[k]
+ for k, v in pipeline.components.items()
+ if k in expected_modules and k not in passed_class_obj
+ }
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+ original_pipe_kwargs = {
+ k: original_config[k]
+ for k, v in original_config.items()
+ if k in optional_kwargs and k not in passed_pipe_kwargs
+ }
+
+        # config attributes that were not expected by the original pipeline are stored as private attributes;
+        # we will pass them as optional arguments if they can be accepted by the pipeline
+ additional_pipe_kwargs = [
+ k[1:]
+ for k in original_config.keys()
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+ ]
+ for k in additional_pipe_kwargs:
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+ text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
+
+        # store unused config attributes as private attributes
+ unused_original_config = {
+ f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
+ for k, v in original_config.items()
+ if k not in text_2_image_kwargs
+ }
+
+ missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
+
+ if len(missing_modules) > 0:
+ raise ValueError(
+ f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
+ )
+
+ model = text_2_image_cls(**text_2_image_kwargs)
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ model.register_to_config(**unused_original_config)
+
+ return model
+
+
+class AutoPipelineForImage2Image(ConfigMixin):
+ r"""
+
+ AutoPipeline for image-to-image generation.
+
+ [`AutoPipelineForImage2Image`] is a generic pipeline class that will be instantiated as one of the image-to-image
+ pipeline classes in diffusers.
+
+ The pipeline type (for example [`StableDiffusionImg2ImgPipeline`]) is automatically selected when created with the
+ `AutoPipelineForImage2Image.from_pretrained(pretrained_model_name_or_path)` or
+ `AutoPipelineForImage2Image.from_pipe(pipeline)` class methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+ r"""
+        Instantiates an image-to-image PyTorch diffusion pipeline from pretrained pipeline weights.
+
+        The `from_pretrained()` method takes care of returning the correct pipeline class instance by:
+        1. Detecting the pipeline class of the `pretrained_model_or_path` based on the `_class_name` property of its
+           config object
+        2. Finding the image-to-image pipeline linked to the pipeline class using pattern matching on the pipeline
+           class name.
+
+        If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetImg2ImgPipeline`] object.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+        >>> from diffusers import AutoPipelineForImage2Image
+
+        >>> pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> print(pipeline.__class__)
+ ```
+ """
+ config = cls.load_config(pretrained_model_or_path)
+ orig_class_name = config["_class_name"]
+
+ if "controlnet" in kwargs:
+ orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+
+ image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
+
+ return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ r"""
+        Instantiates an image-to-image PyTorch diffusion pipeline from another instantiated diffusion pipeline class.
+
+        The `from_pipe()` method takes care of returning the correct pipeline class instance by finding the
+        image-to-image pipeline linked to the pipeline class using pattern matching on the pipeline class name.
+
+        All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
+        additional memory.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pipeline (`DiffusionPipeline`):
+ an instantiated `DiffusionPipeline` object
+
+ Examples:
+
+ ```py
+        >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+
+ >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+ ... )
+
+        >>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
+ ```
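+
+        Components and config options can be overridden when converting. A minimal sketch, replacing the optional
+        safety checker module with `None` while converting (valid here because the source pipeline was loaded with
+        `requires_safety_checker=False`):
+
+        ```py
+        >>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i, safety_checker=None)
+        ```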
+ """
+
+ original_config = dict(pipeline.config)
+ original_cls_name = pipeline.__class__.__name__
+
+ # derive the pipeline class to instantiate
+ image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)
+
+ # define expected module and optional kwargs given the pipeline signature
+ expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
+
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ original_class_obj = {
+ k: pipeline.components[k]
+ for k, v in pipeline.components.items()
+ if k in expected_modules and k not in passed_class_obj
+ }
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+ original_pipe_kwargs = {
+ k: original_config[k]
+ for k, v in original_config.items()
+ if k in optional_kwargs and k not in passed_pipe_kwargs
+ }
+
+        # config attributes that were not expected by the original pipeline are stored as its private attributes;
+        # we will pass them as optional arguments if they can be accepted by the pipeline
+ additional_pipe_kwargs = [
+ k[1:]
+ for k in original_config.keys()
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+ ]
+ for k in additional_pipe_kwargs:
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+ image_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
+
+        # store unused config attributes as private attributes
+ unused_original_config = {
+ f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
+ for k, v in original_config.items()
+ if k not in image_2_image_kwargs
+ }
+
+ missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
+
+ if len(missing_modules) > 0:
+ raise ValueError(
+ f"Pipeline {image_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
+ )
+
+ model = image_2_image_cls(**image_2_image_kwargs)
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ model.register_to_config(**unused_original_config)
+
+ return model
+
+
+class AutoPipelineForInpainting(ConfigMixin):
+ r"""
+
+    AutoPipeline for inpainting.
+
+    [`AutoPipelineForInpainting`] is a generic pipeline class that will be instantiated as one of the inpainting
+    pipeline classes in diffusers.
+
+    The pipeline type (for example [`IFInpaintingPipeline`]) is automatically selected when created with the
+    `AutoPipelineForInpainting.from_pretrained(pretrained_model_name_or_path)` or
+    `AutoPipelineForInpainting.from_pipe(pipeline)` class methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+ r"""
+        Instantiates an inpainting PyTorch diffusion pipeline from pretrained pipeline weights.
+
+        The `from_pretrained()` method takes care of returning the correct pipeline class instance by:
+        1. Detecting the pipeline class of the `pretrained_model_or_path` based on the `_class_name` property of its
+           config object
+        2. Finding the inpainting pipeline linked to the pipeline class using pattern matching on the pipeline class
+           name.
+
+        If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetInpaintPipeline`] object.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+        >>> from diffusers import AutoPipelineForInpainting
+
+        >>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> print(pipeline.__class__)
+ ```
+ """
+ config = cls.load_config(pretrained_model_or_path)
+ orig_class_name = config["_class_name"]
+
+ if "controlnet" in kwargs:
+ orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+
+ inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
+
+ return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ r"""
+        Instantiates an inpainting PyTorch diffusion pipeline from another instantiated diffusion pipeline class.
+
+        The `from_pipe()` method takes care of returning the correct pipeline class instance by finding the
+        inpainting pipeline linked to the pipeline class using pattern matching on the pipeline class name.
+
+        All the modules the pipeline class contains will be used to initialize the new pipeline without reallocating
+        additional memory.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pipeline (`DiffusionPipeline`):
+ an instantiated `DiffusionPipeline` object
+
+ Examples:
+
+ ```py
+        >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
+
+ >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
+ ... "DeepFloyd/IF-I-XL-v1.0", requires_safety_checker=False
+ ... )
+
+ >>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i)
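+
+        >>> # per the task mapping, the IF text-to-image pipeline converts to IFInpaintingPipeline
+        >>> print(pipe_inpaint.__class__.__name__)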
+ ```
+ """
+ original_config = dict(pipeline.config)
+ original_cls_name = pipeline.__class__.__name__
+
+ # derive the pipeline class to instantiate
+ inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name)
+
+ # define expected module and optional kwargs given the pipeline signature
+ expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
+
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ original_class_obj = {
+ k: pipeline.components[k]
+ for k, v in pipeline.components.items()
+ if k in expected_modules and k not in passed_class_obj
+ }
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+ original_pipe_kwargs = {
+ k: original_config[k]
+ for k, v in original_config.items()
+ if k in optional_kwargs and k not in passed_pipe_kwargs
+ }
+
+        # config attributes that were not expected by the original pipeline are stored as private attributes;
+        # we will pass them as optional arguments if they can be accepted by the pipeline
+ additional_pipe_kwargs = [
+ k[1:]
+ for k in original_config.keys()
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+ ]
+ for k in additional_pipe_kwargs:
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+ inpainting_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
+
+        # store unused config attributes as private attributes
+ unused_original_config = {
+ f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
+ for k, v in original_config.items()
+ if k not in inpainting_kwargs
+ }
+
+ missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
+
+ if len(missing_modules) > 0:
+ raise ValueError(
+ f"Pipeline {inpainting_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
+ )
+
+ model = inpainting_cls(**inpainting_kwargs)
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ model.register_to_config(**unused_original_config)
+
+ return model
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 31edb915e8f7..dcb6d07939b4 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -20,8 +20,6 @@
)
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -30,6 +28,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index e3e4b04febac..2cec7c4d663c 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -23,8 +23,6 @@
)
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler
from ...utils import (
is_accelerate_available,
@@ -33,6 +31,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index a4a7dc225947..eabbbb27ea5e 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -25,8 +25,6 @@
)
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler
from ...utils import (
is_accelerate_available,
@@ -35,6 +33,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index 4d99180b3d11..bf75eeacfdf3 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -21,7 +21,6 @@
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
-from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
BaseOutput,
@@ -29,6 +28,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index 6db5260b04d8..007700baf627 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -17,8 +17,6 @@
import torch
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -27,6 +25,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index d6fc26911747..6d616f8c760c 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -17,8 +17,6 @@
import torch
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -27,6 +25,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
index 8a25624b7267..fbfcf91569ef 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -20,8 +20,6 @@
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -30,6 +28,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 26976ad0c925..9ecea9b8393d 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -20,8 +20,6 @@
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -30,6 +28,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index 27bd01984dbf..f4e6349665e3 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -22,8 +22,6 @@
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
@@ -32,6 +30,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index 1bfc6523cdf9..b6ab2ca3fc23 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -5,7 +5,6 @@
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
-from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
logging,
@@ -13,6 +12,7 @@
replace_example_docstring,
)
from ..kandinsky import KandinskyPriorPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index bcbeb78bac10..75be6e54c93f 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -5,7 +5,6 @@
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
-from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
logging,
@@ -13,6 +12,7 @@
replace_example_docstring,
)
from ..kandinsky import KandinskyPriorPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
index fdcbe550860a..a9fab3f3a979 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -22,7 +22,6 @@
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import PriorTransformer
-from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
@@ -32,6 +31,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
index 08c585c5ad73..8693641d5cef 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
@@ -21,7 +21,6 @@
from transformers import CLIPImageProcessor, CLIPVisionModel
from ...models import PriorTransformer
-from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
@@ -29,6 +28,7 @@
randn_tensor,
replace_example_docstring,
)
+from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 29a57470a341..16197def96ff 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -23,9 +23,9 @@
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
-from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index f3b67bebfc8d..1d400a23c972 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -21,10 +21,9 @@
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
-from ...pipelines import DiffusionPipeline
-from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 580417f517be..e24c6af15a01 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -26,9 +26,9 @@
)
from ...models import UNet2DConditionModel, UNet2DModel
-from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index b955ec5320de..a0e3bee76aa9 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -240,6 +240,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoPipelineForImage2Image(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoPipelineForInpainting(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoPipelineForText2Image(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py
new file mode 100644
index 000000000000..595a7a5f25ff
--- /dev/null
+++ b/tests/pipelines/test_pipelines_auto.py
@@ -0,0 +1,201 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+from collections import OrderedDict
+
+import torch
+
+from diffusers import (
+ AutoPipelineForImage2Image,
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
+ ControlNetModel,
+)
+from diffusers.pipelines.auto_pipeline import (
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
+ AUTO_INPAINT_PIPELINES_MAPPING,
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+)
+from diffusers.utils import slow
+
+
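+# repos exercised by the slow tests; each key must match an entry in the AUTO_*_PIPELINES_MAPPING dicts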
+PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", "runwayml/stable-diffusion-v1-5"),
+ ("if", "DeepFloyd/IF-I-XL-v1.0"),
+ ("kandinsky", "kandinsky-community/kandinsky-2-1"),
+ ("kandinsky22", "kandinsky-community/kandinsky-2-2-decoder"),
+ ]
+)
+
+
+class AutoPipelineFastTest(unittest.TestCase):
+ def test_from_pipe_consistent(self):
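+        # round-tripping text2img -> img2img -> text2img through from_pipe must leave the config unchanged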
+ pipe = AutoPipelineForText2Image.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
+ )
+ original_config = dict(pipe.config)
+
+ pipe = AutoPipelineForImage2Image.from_pipe(pipe)
+ assert dict(pipe.config) == original_config
+
+ pipe = AutoPipelineForText2Image.from_pipe(pipe)
+ assert dict(pipe.config) == original_config
+
+ def test_from_pipe_override(self):
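+        # kwargs passed to from_pipe should override the values inherited from the source pipeline's config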
+ pipe = AutoPipelineForText2Image.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
+ )
+
+ pipe = AutoPipelineForImage2Image.from_pipe(pipe, requires_safety_checker=True)
+ assert pipe.config.requires_safety_checker is True
+
+ pipe = AutoPipelineForText2Image.from_pipe(pipe, requires_safety_checker=True)
+ assert pipe.config.requires_safety_checker is True
+
+ def test_from_pipe_consistent_sdxl(self):
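+        # SDXL-specific config entries (requires_aesthetics_score, force_zeros_for_empty_prompt) must survive from_pipe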
+ pipe = AutoPipelineForImage2Image.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
+ requires_aesthetics_score=True,
+ force_zeros_for_empty_prompt=False,
+ )
+
+ original_config = dict(pipe.config)
+
+ pipe = AutoPipelineForText2Image.from_pipe(pipe)
+ pipe = AutoPipelineForImage2Image.from_pipe(pipe)
+
+ assert dict(pipe.config) == original_config
+
+
+@slow
+class AutoPipelineIntegrationTest(unittest.TestCase):
+ def test_pipe_auto(self):
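+        # from_pretrained should resolve each repo to the pipeline class registered for the requested task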
+ for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
+ # test txt2img
+ pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
+ model_repo, variant="fp16", torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForText2Image.from_pipe(pipe_txt2img)
+ self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_txt2img)
+ self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
+
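+            # inpainting needs dedicated Kandinsky checkpoints, so it is skipped for these repos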
+ if "kandinsky" not in model_name:
+ pipe_to = AutoPipelineForInpainting.from_pipe(pipe_txt2img)
+ self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
+
+ del pipe_txt2img, pipe_to
+ gc.collect()
+
+ # test img2img
+
+ pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
+ model_repo, variant="fp16", torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForText2Image.from_pipe(pipe_img2img)
+ self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_img2img)
+ self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
+
+ if "kandinsky" not in model_name:
+ pipe_to = AutoPipelineForInpainting.from_pipe(pipe_img2img)
+ self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
+
+ del pipe_img2img, pipe_to
+ gc.collect()
+
+ # test inpaint
+
+ if "kandinsky" not in model_name:
+ pipe_inpaint = AutoPipelineForInpainting.from_pretrained(
+ model_repo, variant="fp16", torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForText2Image.from_pipe(pipe_inpaint)
+ self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_inpaint)
+ self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
+
+ pipe_to = AutoPipelineForInpainting.from_pipe(pipe_inpaint)
+ self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
+
+ del pipe_inpaint, pipe_to
+ gc.collect()
+
+ def test_from_pipe_consistent(self):
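+        # every from_pretrained / from_pipe combination must produce an identical config dict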
+ for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
+ if model_name in ["kandinsky", "kandinsky22"]:
+ auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image]
+ else:
+ auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting]
+
+ # test from_pretrained
+ for pipe_from_class in auto_pipes:
+ pipe_from = pipe_from_class.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
+ pipe_from_config = dict(pipe_from.config)
+
+ for pipe_to_class in auto_pipes:
+ pipe_to = pipe_to_class.from_pipe(pipe_from)
+ self.assertEqual(dict(pipe_to.config), pipe_from_config)
+
+ del pipe_from, pipe_to
+ gc.collect()
+
+ def test_controlnet(self):
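+        # passing a ControlNetModel should make every auto class resolve to its ControlNet variant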
+ # test from_pretrained
+ model_repo = "runwayml/stable-diffusion-v1-5"
+ controlnet_repo = "lllyasviel/sd-controlnet-canny"
+
+ controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16)
+
+ pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
+ model_repo, controlnet=controlnet, torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+
+ pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
+ model_repo, controlnet=controlnet, torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+
+ pipe_inpaint = AutoPipelineForInpainting.from_pretrained(
+ model_repo, controlnet=controlnet, torch_dtype=torch.float16
+ )
+ self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+
+ # test from_pipe
+ for pipe_from in [pipe_txt2img, pipe_img2img, pipe_inpaint]:
+ pipe_to = AutoPipelineForText2Image.from_pipe(pipe_from)
+ self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+ self.assertEqual(dict(pipe_to.config), dict(pipe_txt2img.config))
+
+ pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_from)
+ self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+ self.assertEqual(dict(pipe_to.config), dict(pipe_img2img.config))
+
+ pipe_to = AutoPipelineForInpainting.from_pipe(pipe_from)
+ self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"])
+ self.assertEqual(dict(pipe_to.config), dict(pipe_inpaint.config))