From 26f04538e3fb9dc18bcc485d85be89f98caa7b19 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jul 2023 10:34:39 +0000 Subject: [PATCH 1/5] initial --- src/diffusers/__init__.py | 3 +++ src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 26 ++++++++++++++++++- .../pipelines/stable_diffusion/__init__.py | 7 ++++- .../pipeline_stable_diffusion.py | 4 ++- .../pipeline_stable_diffusion_img2img.py | 4 ++- .../pipeline_stable_diffusion_inpaint.py | 4 ++- 7 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 39178edc00d1..4f2b944995df 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -72,6 +72,9 @@ PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, + AutoPipelineForTextToImage, + AutoPipelineForImageToImage, + AutoPipelineForInpainting, ) from .schedulers import ( CMStochasticIterativeScheduler, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 937ac1b5e3d7..52466a6e3121 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,7 +24,7 @@ from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput, AutoPipelineForTextToImage, AutoPipelineForImageToImage, AutoPipelineForInpainting from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c2b7399be080..90906845e78d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -336,7 +336,7 @@ def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision ) - if class_obj != DiffusionPipeline: + if class_obj != DiffusionPipeline and not class_obj.__name__.startswith('Auto') : return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -474,6 +474,7 @@ class DiffusionPipeline(ConfigMixin): """ config_name = "model_index.json" _optional_components = [] + _linked_pipelines = [] def register_modules(self, **kwargs): # import it here to avoid circular import @@ -924,6 +925,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) + if hasattr(cls, "task"): + pipeline_class = pipeline_class._get_linked_pipelines(cls.task)[0] + # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( version.parse(config_dict["_diffusers_version"]).base_version @@ -1509,3 +1513,23 @@ def set_attention_slice(self, slice_size: Optional[int]): for module in modules: module.set_attention_slice(slice_size) + + @classmethod + def _get_linked_pipelines(cls, task): + linked_classes_str = list(set([cls.__name__] + cls._linked_pipelines)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + linked_classes = [ + getattr(diffusers_library, c) for c in linked_classes_str if hasattr(diffusers_library, c) + ] + linked_classes = [c for c in linked_classes if c.task == task] + return linked_classes + + +class AutoPipelineForTextToImage(DiffusionPipeline): + task = "TextToImage" + +class AutoPipelineForImageToImage(DiffusionPipeline): + task = "ImageToImage" + +class AutoPipelineForInpainting(DiffusionPipeline): + task = "Inpaint" \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 33ab05a1dacb..4fd5a4b9de1d 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -4,6 +4,7 @@ import numpy as np import PIL from PIL import Image +from enum import Enum from ...utils import ( BaseOutput, @@ -35,6 +36,10 @@ class StableDiffusionPipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] +class StableDiffusionPipelines(Enum): + StableDiffusionPipeline = 1 + StableDiffusionImg2ImgPipeline = 2 + StableDiffusionInpaintPipeline = 3 try: if not (is_transformers_available() and is_torch_available()): @@ -133,4 +138,4 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9ad4d404fdbe..9743e08e2f90 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -34,7 +34,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -105,6 +105,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "TextToImage" def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f8874ba2cfae..0f18c8c62813 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -37,7 +37,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -136,6 +136,8 @@ class StableDiffusionImg2ImgPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "ImageToImage" def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b012a79ba855..c4f866230d9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -29,7 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from . import StableDiffusionPipelineOutput, StableDiffusionPipelines from .safety_checker import StableDiffusionSafetyChecker @@ -200,6 +200,8 @@ class StableDiffusionInpaintPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _linked_pipelines = [e.name for e in StableDiffusionPipelines] + task = "Inpaint" def __init__( self, From b9285f79225cade81e3e61eda16f415d44758fe8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jul 2023 10:38:54 +0000 Subject: [PATCH 2/5] make style --- src/diffusers/__init__.py | 6 +++--- src/diffusers/pipelines/__init__.py | 9 ++++++++- src/diffusers/pipelines/pipeline_utils.py | 12 ++++++------ src/diffusers/pipelines/stable_diffusion/__init__.py | 6 ++++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4f2b944995df..2a6d40d2e91b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -59,6 +59,9 @@ ) from .pipelines import ( AudioPipelineOutput, + AutoPipelineForImageToImage, + AutoPipelineForInpainting, + AutoPipelineForTextToImage, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, @@ -72,9 +75,6 @@ PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, - AutoPipelineForTextToImage, - AutoPipelineForImageToImage, - AutoPipelineForInpainting, ) from .schedulers import ( CMStochasticIterativeScheduler, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 52466a6e3121..b0cdde30e780 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,7 +24,14 @@ from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput, AutoPipelineForTextToImage, AutoPipelineForImageToImage, AutoPipelineForInpainting + from .pipeline_utils import ( + AudioPipelineOutput, + AutoPipelineForImageToImage, + AutoPipelineForInpainting, + AutoPipelineForTextToImage, + DiffusionPipeline, + ImagePipelineOutput, + ) from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 90906845e78d..6c61bf41f651 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -336,7 +336,7 @@ def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision ) - if class_obj != DiffusionPipeline and not class_obj.__name__.startswith('Auto') : + if class_obj != DiffusionPipeline and not class_obj.__name__.startswith("Auto"): return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -1513,14 +1513,12 @@ def set_attention_slice(self, slice_size: Optional[int]): for module in modules: module.set_attention_slice(slice_size) - + @classmethod def _get_linked_pipelines(cls, task): linked_classes_str = list(set([cls.__name__] + cls._linked_pipelines)) diffusers_library = importlib.import_module(__name__.split(".")[0]) - linked_classes = [ - getattr(diffusers_library, c) for c in linked_classes_str if hasattr(diffusers_library, c) - ] + linked_classes = [getattr(diffusers_library, c) for c in linked_classes_str if hasattr(diffusers_library, c)] linked_classes = [c for c in linked_classes if c.task == task] return linked_classes @@ -1528,8 +1526,10 @@ def _get_linked_pipelines(cls, task): class AutoPipelineForTextToImage(DiffusionPipeline): task = "TextToImage" + class AutoPipelineForImageToImage(DiffusionPipeline): task = "ImageToImage" + class AutoPipelineForInpainting(DiffusionPipeline): - task = "Inpaint" \ No newline at end of file + task = "Inpaint" diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 4fd5a4b9de1d..95a7663bcc6c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,10 +1,10 @@ from dataclasses import dataclass +from enum import Enum from typing import List, Optional, Union import numpy as np import PIL from PIL import Image -from enum import Enum from ...utils import ( BaseOutput, @@ -36,11 +36,13 @@ class StableDiffusionPipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] + class StableDiffusionPipelines(Enum): StableDiffusionPipeline = 1 StableDiffusionImg2ImgPipeline = 2 StableDiffusionInpaintPipeline = 3 + try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() @@ -138,4 +140,4 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker \ No newline at end of file + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker From a0c1aea2a9f4b0973f2a828273a56264eabd512d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jul 2023 11:18:40 +0000 Subject: [PATCH 3/5] add from_pipe --- src/diffusers/pipelines/pipeline_utils.py | 42 ++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6c61bf41f651..9095f733d081 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -926,7 +926,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if hasattr(cls, "task"): - pipeline_class = pipeline_class._get_linked_pipelines(cls.task)[0] + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError(f"can't find a pipeline with task {cls.task} for pipeline class {pipeline_class}") + else: + pipeline_class = pipeline_class_linked[0] # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( @@ -1526,10 +1530,46 @@ def _get_linked_pipelines(cls, task): class AutoPipelineForTextToImage(DiffusionPipeline): task = "TextToImage" + @classmethod + def from_pipe(cls, pipeline): + pipeline_class = pipeline.__class__ + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") + else: + pipeline_class = pipeline_class_linked[0] + + model = pipeline_class(**pipeline.components) + return model + class AutoPipelineForImageToImage(DiffusionPipeline): task = "ImageToImage" + @classmethod + def from_pipe(cls, pipeline): + pipeline_class = pipeline.__class__ + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") + else: + pipeline_class = pipeline_class_linked[0] + + model = pipeline_class(**pipeline.components) + return model + class AutoPipelineForInpainting(DiffusionPipeline): task = "Inpaint" + + @classmethod + def from_pipe(cls, pipeline): + pipeline_class = pipeline.__class__ + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") + else: + pipeline_class = pipeline_class_linked[0] + + model = pipeline_class(**pipeline.components) + return model From e4800a5457968550680545e04731cf3a5eb1007e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jul 2023 11:36:12 +0000 Subject: [PATCH 4/5] fix --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 9095f733d081..4dc219113a0d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -925,7 +925,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) - if hasattr(cls, "task"): + if cls.__name__.startswith("Auto") and hasattr(cls, "task"): pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) if len(pipeline_class_linked) == 0: raise ValueError(f"can't find a pipeline with task {cls.task} for pipeline class {pipeline_class}") From bdca6cb7de36bfaa714d1d90a992746ae50ce455 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jul 2023 11:44:07 +0000 Subject: [PATCH 5/5] refactor --- src/diffusers/pipelines/pipeline_utils.py | 48 +++++++---------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 4dc219113a0d..25996248aee2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -925,7 +925,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) - if cls.__name__.startswith("Auto") and hasattr(cls, "task"): + if cls.__name__.startswith("Auto"): pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) if len(pipeline_class_linked) == 0: raise ValueError(f"can't find a pipeline with task {cls.task} for pipeline class {pipeline_class}") @@ -1526,50 +1526,30 @@ def _get_linked_pipelines(cls, task): linked_classes = [c for c in linked_classes if c.task == task] return linked_classes - -class AutoPipelineForTextToImage(DiffusionPipeline): - task = "TextToImage" - @classmethod def from_pipe(cls, pipeline): pipeline_class = pipeline.__class__ - pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) - if len(pipeline_class_linked) == 0: - raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") - else: - pipeline_class = pipeline_class_linked[0] + + if cls.__name__.startswith("Auto"): + pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) + if len(pipeline_class_linked) == 0: + raise ValueError( + f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}" + ) + else: + pipeline_class = pipeline_class_linked[0] model = pipeline_class(**pipeline.components) return model -class AutoPipelineForImageToImage(DiffusionPipeline): - task = "ImageToImage" +class AutoPipelineForTextToImage(DiffusionPipeline): + task = "TextToImage" - @classmethod - def from_pipe(cls, pipeline): - pipeline_class = pipeline.__class__ - pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) - if len(pipeline_class_linked) == 0: - raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") - else: - pipeline_class = pipeline_class_linked[0] - model = pipeline_class(**pipeline.components) - return model +class AutoPipelineForImageToImage(DiffusionPipeline): + task = "ImageToImage" class AutoPipelineForInpainting(DiffusionPipeline): task = "Inpaint" - - @classmethod - def from_pipe(cls, pipeline): - pipeline_class = pipeline.__class__ - pipeline_class_linked = pipeline_class._get_linked_pipelines(cls.task) - if len(pipeline_class_linked) == 0: - raise ValueError(f"can't find linked pipeline with task {cls.task} for pipeline class {pipeline_class}") - else: - pipeline_class = pipeline_class_linked[0] - - model = pipeline_class(**pipeline.components) - return model