From 5f7023a2e7e6a91df94d47a3915b9d04ecaf56a3 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 13 Feb 2023 22:17:37 -0800 Subject: [PATCH] karras diffusion schedulers Enum->Union --- docs/source/en/api/schedulers/overview.mdx | 4 --- src/diffusers/schedulers/__init__.py | 31 +++++++++++++++++-- src/diffusers/schedulers/scheduling_ddim.py | 4 +-- .../schedulers/scheduling_ddim_flax.py | 4 +-- src/diffusers/schedulers/scheduling_ddpm.py | 4 +-- .../schedulers/scheduling_ddpm_flax.py | 4 +-- .../schedulers/scheduling_deis_multistep.py | 4 +-- .../scheduling_dpmsolver_multistep.py | 4 +-- .../scheduling_dpmsolver_multistep_flax.py | 4 +-- .../scheduling_dpmsolver_singlestep.py | 4 +-- .../scheduling_euler_ancestral_discrete.py | 4 +-- .../schedulers/scheduling_euler_discrete.py | 4 +-- .../schedulers/scheduling_heun_discrete.py | 4 +-- .../scheduling_k_dpm_2_ancestral_discrete.py | 4 +-- .../schedulers/scheduling_k_dpm_2_discrete.py | 4 +-- .../schedulers/scheduling_lms_discrete.py | 4 +-- .../scheduling_lms_discrete_flax.py | 4 +-- src/diffusers/schedulers/scheduling_pndm.py | 4 +-- .../schedulers/scheduling_pndm_flax.py | 4 +-- src/diffusers/schedulers/scheduling_utils.py | 30 +++++++++--------- .../schedulers/scheduling_utils_flax.py | 15 ++++----- 21 files changed, 86 insertions(+), 62 deletions(-) diff --git a/docs/source/en/api/schedulers/overview.mdx b/docs/source/en/api/schedulers/overview.mdx index d27fbe10c528..1328cfd48397 100644 --- a/docs/source/en/api/schedulers/overview.mdx +++ b/docs/source/en/api/schedulers/overview.mdx @@ -80,7 +80,3 @@ The base class [`SchedulerMixin`] implements low level utilities used by multipl The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call. [[autodoc]] schedulers.scheduling_utils.SchedulerOutput - -### KarrasDiffusionSchedulers - -[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 3746acd5b576..e482b3bd1dd6 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. +from typing import Union + from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available @@ -39,7 +41,7 @@ from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler - from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + from .scheduling_utils import SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: @@ -56,7 +58,6 @@ from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler from .scheduling_utils_flax import ( - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, @@ -70,3 +71,29 @@ from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler + + +# NOTE keep in sync with ./scheduling_utils.py `karras_diffusion_scheduler_compatibles` +KarrasDiffusionSchedulers = Union[ + DDIMScheduler, + DDPMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + DEISMultistepScheduler, +] + +# NOTE keep in sync with ./scheduling_utils_flax.py `flax_karras_diffusion_scheduler_compatibles` +FlaxKarrasDiffusionSchedulers = Union[ + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxLMSDiscreteScheduler, + FlaxDPMSolverMultistepScheduler, +] diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 4eeb67f6b182..859fee2af670 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -24,7 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles @dataclass @@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 565b7ff3c9c2..fc32b753fdfd 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -24,10 +24,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, + flax_karras_diffusion_scheduler_compatibles, get_velocity_common, ) @@ -101,7 +101,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + _compatibles = flax_karras_diffusion_scheduler_compatibles dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 9d8aa6fa5b2f..bbeeb945f8bc 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles @dataclass @@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 3179538e8394..15e90780ced7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -24,10 +24,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, + flax_karras_diffusion_scheduler_compatibles, get_velocity_common, ) @@ -84,7 +84,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + _compatibles = flax_karras_diffusion_scheduler_compatibles dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 1ad5480b7878..019fc8df4dd7 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -105,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 0630ea1d1fe7..b087f1f39c7b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -116,7 +116,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index cadf782fb3ae..d41bb39a6152 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -24,10 +24,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, + flax_karras_diffusion_scheduler_compatibles, ) @@ -139,7 +139,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + _compatibles = flax_karras_diffusion_scheduler_compatibles dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 0225d8027bc3..95ca3418856e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -115,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 45f939aafe70..8c8ebd85fe7a 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 1a7a46bc5d32..126f0eaf82cb 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -74,7 +74,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): [`"linear"`, `"log_linear"`]. """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 0dea944b6fef..1571259a12be 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -47,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 711bdf2d5ef0..09b954287f81 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -19,7 +19,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index a46cc060522c..31717a75ea83 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -48,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 88537a32df53..ace084f75792 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles @dataclass @@ -69,7 +69,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index e105ded997d2..70c481867fc5 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -22,10 +22,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, + flax_karras_diffusion_scheduler_compatibles, ) @@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + _compatibles = flax_karras_diffusion_scheduler_compatibles dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 065a07e955f8..e445ccd40f71 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -91,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] + _compatibles = karras_diffusion_scheduler_compatibles order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 572da534643b..beca741ed639 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -24,10 +24,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, + flax_karras_diffusion_scheduler_compatibles, ) @@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + _compatibles = flax_karras_diffusion_scheduler_compatibles dtype: jnp.dtype pndm_order: int diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f4103d4d62cc..f2f4d2026e9e 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,7 +14,6 @@ import importlib import os from dataclasses import dataclass -from enum import Enum from typing import Any, Dict, Optional, Union import torch @@ -24,20 +23,21 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json" - -class KarrasDiffusionSchedulers(Enum): - DDIMScheduler = 1 - DDPMScheduler = 2 - PNDMScheduler = 3 - LMSDiscreteScheduler = 4 - EulerDiscreteScheduler = 5 - HeunDiscreteScheduler = 6 - EulerAncestralDiscreteScheduler = 7 - DPMSolverMultistepScheduler = 8 - DPMSolverSinglestepScheduler = 9 - KDPM2DiscreteScheduler = 10 - KDPM2AncestralDiscreteScheduler = 11 - DEISMultistepScheduler = 12 +# NOTE keep in sync with ./__init__.py `KarrasDiffusionSchedulers` +karras_diffusion_scheduler_compatibles = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "KDPM2DiscreteScheduler", + "KDPM2AncestralDiscreteScheduler", + "DEISMultistepScheduler", +] @dataclass diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 9708c0883760..c685b4715cae 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -15,7 +15,6 @@ import math import os from dataclasses import dataclass -from enum import Enum from typing import Any, Dict, Optional, Tuple, Union import flax @@ -27,12 +26,14 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json" -class FlaxKarrasDiffusionSchedulers(Enum): - FlaxDDIMScheduler = 1 - FlaxDDPMScheduler = 2 - FlaxPNDMScheduler = 3 - FlaxLMSDiscreteScheduler = 4 - FlaxDPMSolverMultistepScheduler = 5 +# NOTE keep in sync with ./__init__.py `FlaxKarrasDiffusionSchedulers` +flax_karras_diffusion_scheduler_compatibles = [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxPNDMScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxDPMSolverMultistepScheduler", +] @dataclass