diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4cf4cf36e4e4..5abee475d9c9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -186,6 +186,8 @@ title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/auto_pipeline + title: AutoPipeline - local: api/pipelines/consistency_models title: Consistency Models - local: api/pipelines/controlnet diff --git a/docs/source/en/api/pipelines/auto_pipeline.mdx b/docs/source/en/api/pipelines/auto_pipeline.mdx new file mode 100644 index 000000000000..4ae2b86ac269 --- /dev/null +++ b/docs/source/en/api/pipelines/auto_pipeline.mdx @@ -0,0 +1,68 @@ + + +# AutoPipeline + +In many cases, one checkpoint can be used for multiple tasks. For example, you may be able to use the same checkpoint for Text-to-Image, Image-to-Image, and Inpainting. However, you'll need to know the pipeline class names linked to your checkpoint. + +AutoPipeline is designed to make it easy for you to use multiple pipelines in your workflow. We currently provide 3 AutoPipeline classes to perform three different tasks, i.e. [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. You'll need to choose the AutoPipeline class based on the task you want to perform and use it to automatically retrieve the relevant pipeline given the name/path to the pre-trained weights. + +For example, to perform Image-to-Image with the SD1.5 checkpoint, you can do + +```python +from diffusers import PipelineForImageToImage + +pipe_i2i = PipelineForImageoImage.from_pretrained("runwayml/stable-diffusion-v1-5") +``` + +It will also help you switch between tasks seamlessly using the same checkpoint without reallocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do + +```python +from diffusers import PipelineForInpainting + +pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i) +``` +All the components will be transferred to the inpainting pipeline with zero cost. + + +Currently AutoPipeline support the Text-to-Image, Image-to-Image, and Inpainting tasks for below diffusion models: +- [stable Diffusion](./stable_diffusion) +- [Stable Diffusion Controlnet](./api/pipelines/controlnet) +- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl) +- [IF](./if) +- [Kandinsky](./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky) +- [Kandinsky 2.2]()(./kandinsky) + + +## AutoPipelineForText2Image + +[[autodoc]] AutoPipelineForText2Image + - all + - from_pretrained + - from_pipe + + +## AutoPipelineForImage2Image + +[[autodoc]] AutoPipelineForImage2Image + - all + - from_pretrained + - from_pipe + +## AutoPipelineForInpainting + +[[autodoc]] AutoPipelineForInpainting + - all + - from_pretrained + - from_pipe + + diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 445e63b4d406..81b863c7e65e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -61,6 +61,9 @@ ) from .pipelines import ( AudioPipelineOutput, + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 802ae4f5bc94..cd667292f907 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -17,6 +17,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py new file mode 100644 index 000000000000..c827231ada7d --- /dev/null +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -0,0 +1,834 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections import OrderedDict + +from ..configuration_utils import ConfigMixin +from .controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetPipeline, +) +from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline +from .kandinsky import KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline +from .kandinsky2_2 import KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline +from .stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, +) +from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) + + +AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionPipeline), + ("stable-diffusion-xl", StableDiffusionXLPipeline), + ("if", IFPipeline), + ("kandinsky", KandinskyPipeline), + ("kandinsky22", KandinskyV22Pipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), + ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), + ] +) + +AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionImg2ImgPipeline), + ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline), + ("if", IFImg2ImgPipeline), + ("kandinsky", KandinskyImg2ImgPipeline), + ("kandinsky22", KandinskyV22Img2ImgPipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), + ] +) + +AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionInpaintPipeline), + ("stable-diffusion-xl", StableDiffusionXLInpaintPipeline), + ("if", IFInpaintingPipeline), + ("kandinsky", KandinskyInpaintPipeline), + ("kandinsky22", KandinskyV22InpaintPipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), + ] +) + +SUPPORTED_TASKS_MAPPINGS = [ + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + AUTO_INPAINT_PIPELINES_MAPPING, +] + + +def _get_task_class(mapping, pipeline_class_name): + def get_model(pipeline_class_name): + for task_mapping in SUPPORTED_TASKS_MAPPINGS: + for model_name, pipeline in task_mapping.items(): + if pipeline.__name__ == pipeline_class_name: + return model_name + + model_name = get_model(pipeline_class_name) + + if model_name is not None: + task_class = mapping.get(model_name, None) + if task_class is not None: + return task_class + raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}") + + +def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + return expected_modules, optional_parameters + + +class AutoPipelineForText2Image(ConfigMixin): + r""" + + AutoPipeline for text-to-image generation. + + [`AutoPipelineForText2Image`] is a generic pipeline class that will be instantiated as one of the text-to-image + pipeline class in diffusers. + + The pipeline type (for example [`StableDiffusionPipeline`]) is automatically selected when created with the + AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path) or + AutoPipelineForText2Image.from_pipe(pipeline) class methods . + + This class cannot be instantiated using __init__() (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the text-to-image pipeline linked to the pipeline class using pattern matching on pipeline class + name. + + If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetPipeline`] object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForTextToImage + + >>> pipeline = AutoPipelineForTextToImage.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> print(pipeline.__class__) + ``` + """ + config = cls.load_config(pretrained_model_or_path) + orig_class_name = config["_class_name"] + + if "controlnet" in kwargs: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + + text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) + + return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a text-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-image + pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline contains will be used to initialize the new pipeline without reallocating + additional memoery. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + ```py + >>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForImageToImage + + >>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False + ... ) + + >>> pipe_t2i = AutoPipelineForTextToImage.from_pipe(pipe_t2i) + ``` + """ + + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config that were not expected by original pipeline is stored as private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in text_2_image_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed" + ) + + model = text_2_image_cls(**text_2_image_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model + + +class AutoPipelineForImage2Image(ConfigMixin): + r""" + + AutoPipeline for image-to-image generation. + + [`AutoPipelineForImage2Image`] is a generic pipeline class that will be instantiated as one of the image-to-image + pipeline classes in diffusers. + + The pipeline type (for example [`StableDiffusionImg2ImgPipeline`]) is automatically selected when created with the + `AutoPipelineForImage2Image.from_pretrained(pretrained_model_name_or_path)` or + `AutoPipelineForImage2Image.from_pipe(pipeline)` class methods. + + This class cannot be instantiated using __init__() (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class + name. + + If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetImg2ImgPipeline object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForTextToImage + + >>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> print(pipeline.__class__) + ``` + """ + config = cls.load_config(pretrained_model_or_path) + orig_class_name = config["_class_name"] + + if "controlnet" in kwargs: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) + + return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a image-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the + image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline contains will be used to initialize the new pipeline without reallocating + additional memoery. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + Examples: + + ```py + >>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForImageToImage + + >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False + ... ) + + >>> pipe_i2i = AutoPipelineForImageToImage.from_pipe(pipe_t2i) + ``` + """ + + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config attribute that were not expected by original pipeline is stored as its private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + image_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in image_2_image_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {image_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed" + ) + + model = image_2_image_cls(**image_2_image_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model + + +class AutoPipelineForInpainting(ConfigMixin): + r""" + + AutoPipeline for inpainting generation. + + [`AutoPipelineForInpainting`] is a generic pipeline class that will be instantiated as one of the inpainting + pipeline class in diffusers. + + The pipeline type (for example [`IFInpaintingPipeline`]) is automatically selected when created with the + AutoPipelineForInpainting.from_pretrained(pretrained_model_name_or_path) or + AutoPipelineForInpainting.from_pipe(pipeline) class methods . + + This class cannot be instantiated using __init__() (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the inpainting pipeline linked to the pipeline class using pattern matching on pipeline class name. + + If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetInpaintPipeline object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForTextToImage + + >>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> print(pipeline.__class__) + ``` + """ + config = cls.load_config(pretrained_model_or_path) + orig_class_name = config["_class_name"] + + if "controlnet" in kwargs: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + + inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) + + return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a inpainting Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the inpainting + pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline class contain will be used to initialize the new pipeline without reallocating + additional memoery. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + Examples: + + ```py + >>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForInpainting + + >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", requires_safety_checker=False + ... ) + + >>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i) + ``` + """ + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config that were not expected by original pipeline is stored as private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + inpainting_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in inpainting_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {inpainting_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed" + ) + + model = inpainting_cls(**inpainting_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index 31edb915e8f7..dcb6d07939b4 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -20,8 +20,6 @@ ) from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import ( is_accelerate_available, @@ -30,6 +28,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_encoder import MultilingualCLIP diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index e3e4b04febac..2cec7c4d663c 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -23,8 +23,6 @@ ) from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, @@ -33,6 +31,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_encoder import MultilingualCLIP diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py index a4a7dc225947..eabbbb27ea5e 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -25,8 +25,6 @@ ) from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, @@ -35,6 +33,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_encoder import MultilingualCLIP diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py index 4d99180b3d11..bf75eeacfdf3 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -21,7 +21,6 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer -from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( BaseOutput, @@ -29,6 +28,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py index 6db5260b04d8..007700baf627 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -17,8 +17,6 @@ import torch from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, @@ -27,6 +25,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py index d6fc26911747..6d616f8c760c 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -17,8 +17,6 @@ import torch from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, @@ -27,6 +25,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py index 8a25624b7267..fbfcf91569ef 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -20,8 +20,6 @@ from PIL import Image from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, @@ -30,6 +28,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py index 26976ad0c925..9ecea9b8393d 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -20,8 +20,6 @@ from PIL import Image from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, @@ -30,6 +28,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py index 27bd01984dbf..f4e6349665e3 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -22,8 +22,6 @@ from PIL import Image from ...models import UNet2DConditionModel, VQModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, @@ -32,6 +30,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py index 1bfc6523cdf9..b6ab2ca3fc23 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -5,7 +5,6 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer -from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( logging, @@ -13,6 +12,7 @@ replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py index bcbeb78bac10..75be6e54c93f 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -5,7 +5,6 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer -from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( logging, @@ -13,6 +12,7 @@ replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index fdcbe550860a..a9fab3f3a979 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -22,7 +22,6 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer from ...models import PriorTransformer -from ...pipelines import DiffusionPipeline from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, @@ -32,6 +31,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline from .renderer import ShapERenderer diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index 08c585c5ad73..8693641d5cef 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -21,7 +21,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModel from ...models import PriorTransformer -from ...pipelines import DiffusionPipeline from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, @@ -29,6 +28,7 @@ randn_tensor, replace_example_docstring, ) +from ..pipeline_utils import DiffusionPipeline from .renderer import ShapERenderer diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 29a57470a341..16197def96ff 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -23,9 +23,9 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index f3b67bebfc8d..1d400a23c972 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -21,10 +21,9 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import UnCLIPScheduler from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index 580417f517be..e24c6af15a01 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -26,9 +26,9 @@ ) from ...models import UNet2DConditionModel, UNet2DModel -from ...pipelines import DiffusionPipeline, ImagePipelineOutput from ...schedulers import UnCLIPScheduler from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b955ec5320de..a0e3bee76aa9 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -240,6 +240,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoPipelineForImage2Image(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoPipelineForInpainting(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoPipelineForText2Image(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py new file mode 100644 index 000000000000..595a7a5f25ff --- /dev/null +++ b/tests/pipelines/test_pipelines_auto.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest +from collections import OrderedDict + +import torch + +from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + ControlNetModel, +) +from diffusers.pipelines.auto_pipeline import ( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + AUTO_INPAINT_PIPELINES_MAPPING, + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, +) +from diffusers.utils import slow + + +PRETRAINED_MODEL_REPO_MAPPING = OrderedDict( + [ + ("stable-diffusion", "runwayml/stable-diffusion-v1-5"), + ("if", "DeepFloyd/IF-I-XL-v1.0"), + ("kandinsky", "kandinsky-community/kandinsky-2-1"), + ("kandinsky22", "kandinsky-community/kandinsky-2-2-decoder"), + ] +) + + +class AutoPipelineFastTest(unittest.TestCase): + def test_from_pipe_consistent(self): + pipe = AutoPipelineForText2Image.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False + ) + original_config = dict(pipe.config) + + pipe = AutoPipelineForImage2Image.from_pipe(pipe) + assert dict(pipe.config) == original_config + + pipe = AutoPipelineForText2Image.from_pipe(pipe) + assert dict(pipe.config) == original_config + + def test_from_pipe_override(self): + pipe = AutoPipelineForText2Image.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False + ) + + pipe = AutoPipelineForImage2Image.from_pipe(pipe, requires_safety_checker=True) + assert pipe.config.requires_safety_checker is True + + pipe = AutoPipelineForText2Image.from_pipe(pipe, requires_safety_checker=True) + assert pipe.config.requires_safety_checker is True + + def test_from_pipe_consistent_sdxl(self): + pipe = AutoPipelineForImage2Image.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-xl-pipe", + requires_aesthetics_score=True, + force_zeros_for_empty_prompt=False, + ) + + original_config = dict(pipe.config) + + pipe = AutoPipelineForText2Image.from_pipe(pipe) + pipe = AutoPipelineForImage2Image.from_pipe(pipe) + + assert dict(pipe.config) == original_config + + +@slow +class AutoPipelineIntegrationTest(unittest.TestCase): + def test_pipe_auto(self): + for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items(): + # test txt2img + pipe_txt2img = AutoPipelineForText2Image.from_pretrained( + model_repo, variant="fp16", torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForText2Image.from_pipe(pipe_txt2img) + self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_txt2img) + self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name]) + + if "kandinsky" not in model_name: + pipe_to = AutoPipelineForInpainting.from_pipe(pipe_txt2img) + self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name]) + + del pipe_txt2img, pipe_to + gc.collect() + + # test img2img + + pipe_img2img = AutoPipelineForImage2Image.from_pretrained( + model_repo, variant="fp16", torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForText2Image.from_pipe(pipe_img2img) + self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_img2img) + self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name]) + + if "kandinsky" not in model_name: + pipe_to = AutoPipelineForInpainting.from_pipe(pipe_img2img) + self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name]) + + del pipe_img2img, pipe_to + gc.collect() + + # test inpaint + + if "kandinsky" not in model_name: + pipe_inpaint = AutoPipelineForInpainting.from_pretrained( + model_repo, variant="fp16", torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForText2Image.from_pipe(pipe_inpaint) + self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_inpaint) + self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name]) + + pipe_to = AutoPipelineForInpainting.from_pipe(pipe_inpaint) + self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name]) + + del pipe_inpaint, pipe_to + gc.collect() + + def test_from_pipe_consistent(self): + for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items(): + if model_name in ["kandinsky", "kandinsky22"]: + auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image] + else: + auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting] + + # test from_pretrained + for pipe_from_class in auto_pipes: + pipe_from = pipe_from_class.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16) + pipe_from_config = dict(pipe_from.config) + + for pipe_to_class in auto_pipes: + pipe_to = pipe_to_class.from_pipe(pipe_from) + self.assertEqual(dict(pipe_to.config), pipe_from_config) + + del pipe_from, pipe_to + gc.collect() + + def test_controlnet(self): + # test from_pretrained + model_repo = "runwayml/stable-diffusion-v1-5" + controlnet_repo = "lllyasviel/sd-controlnet-canny" + + controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16) + + pipe_txt2img = AutoPipelineForText2Image.from_pretrained( + model_repo, controlnet=controlnet, torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + + pipe_img2img = AutoPipelineForImage2Image.from_pretrained( + model_repo, controlnet=controlnet, torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + + pipe_inpaint = AutoPipelineForInpainting.from_pretrained( + model_repo, controlnet=controlnet, torch_dtype=torch.float16 + ) + self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + + # test from_pipe + for pipe_from in [pipe_txt2img, pipe_img2img, pipe_inpaint]: + pipe_to = AutoPipelineForText2Image.from_pipe(pipe_from) + self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + self.assertEqual(dict(pipe_to.config), dict(pipe_txt2img.config)) + + pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_from) + self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + self.assertEqual(dict(pipe_to.config), dict(pipe_img2img.config)) + + pipe_to = AutoPipelineForInpainting.from_pipe(pipe_from) + self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"]) + self.assertEqual(dict(pipe_to.config), dict(pipe_inpaint.config))