diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 39178edc00d1..2a6d40d2e91b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -59,6 +59,9 @@ ) from .pipelines import ( AudioPipelineOutput, + AutoPipelineForImageToImage, + AutoPipelineForInpainting, + AutoPipelineForTextToImage, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 937ac1b5e3d7..b0cdde30e780 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,7 +24,14 @@ from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pipeline_utils import ( + AudioPipelineOutput, + AutoPipelineForImageToImage, + AutoPipelineForInpainting, + AutoPipelineForTextToImage, + DiffusionPipeline, + ImagePipelineOutput, + ) from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c2b7399be080..25996248aee2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -336,7 +336,7 @@ def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision ) - if class_obj != DiffusionPipeline: + if class_obj != DiffusionPipeline and not class_obj.__name__.startswith("Auto"): return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -474,6 +474,7 @@ class DiffusionPipeline(ConfigMixin): """ config_name = "model_index.json" _optional_components = [] + _linked_pipelines = [] def register_modules(self, **kwargs): # import it here to avoid circular import @@ -924,6 +925,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) + if cls.__name__.startswith("Auto"): + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError(f"can't find a pipeline with task {cls.task} for pipeline class {pipeline_class}") + else: + pipeline_class = pipeline_class_linked[0] + # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( version.parse(config_dict["_diffusers_version"]).base_version @@ -1509,3 +1517,39 @@ def set_attention_slice(self, slice_size: Optional[int]): for module in modules: module.set_attention_slice(slice_size) + + @classmethod + def _get_linked_pipelines(cls, task): + linked_classes_str = list(set([cls.__name__] + cls._linked_pipelines)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + linked_classes = [getattr(diffusers_library, c) for c in linked_classes_str if hasattr(diffusers_library, c)] + linked_classes = [c for c in linked_classes if c.task == task] + return linked_classes + + @classmethod + def from_pipe(cls, pipeline): + pipeline_class = pipeline.__class__ + + if cls.__name__.startswith("Auto"): + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError( + f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}" + ) + else: + pipeline_class = pipeline_class_linked[0] + + model = pipeline_class(**pipeline.components) + return model + + +class AutoPipelineForTextToImage(DiffusionPipeline): + task = "TextToImage" + + +class AutoPipelineForImageToImage(DiffusionPipeline): + task = "ImageToImage" + + +class AutoPipelineForInpainting(DiffusionPipeline): + task = "Inpaint" diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 33ab05a1dacb..95a7663bcc6c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from enum import Enum from typing import List, Optional, Union import numpy as np @@ -36,6 +37,12 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected: Optional[List[bool]] +class StableDiffusionPipelines(Enum): + StableDiffusionPipeline = 1 + StableDiffusionImg2ImgPipeline = 2 + StableDiffusionInpaintPipeline = 3 + + try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9ad4d404fdbe..9743e08e2f90 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -34,7 +34,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -105,6 +105,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "TextToImage" def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f8874ba2cfae..0f18c8c62813 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -37,7 +37,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -136,6 +136,8 @@ class StableDiffusionImg2ImgPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "ImageToImage" def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b012a79ba855..c4f866230d9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -29,7 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -200,6 +200,8 @@ class StableDiffusionInpaintPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "Inpaint" def __init__( self,