Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/source/en/api/schedulers/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,3 @@ The base class [`SchedulerMixin`] implements low level utilities used by multipl
The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call.

[[autodoc]] schedulers.scheduling_utils.SchedulerOutput

### KarrasDiffusionSchedulers

[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers
Comment on lines -83 to -86
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc builder doesn't know how to autodoc Enums. I think it's ok to remove

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add it manually?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's what I'm thinking : #2346 (comment)

31 changes: 29 additions & 2 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from typing import Union

from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available


Expand All @@ -39,7 +41,7 @@
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler

try:
Expand All @@ -56,7 +58,6 @@
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import (
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
Expand All @@ -70,3 +71,29 @@
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .scheduling_lms_discrete import LMSDiscreteScheduler


# NOTE keep in sync with ./scheduling_utils.py `karras_diffusion_scheduler_compatibles`
KarrasDiffusionSchedulers = Union[
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
DEISMultistepScheduler,
]

# NOTE keep in sync with ./scheduling_utils_flax.py `flax_karras_diffusion_scheduler_compatibles`
FlaxKarrasDiffusionSchedulers = Union[
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxPNDMScheduler,
FlaxLMSDiscreteScheduler,
FlaxDPMSolverMultistepScheduler,
]
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles


@dataclass
Expand Down Expand Up @@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
flax_karras_diffusion_scheduler_compatibles,
get_velocity_common,
)

Expand Down Expand Up @@ -101,7 +101,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_compatibles = flax_karras_diffusion_scheduler_compatibles

dtype: jnp.dtype

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles


@dataclass
Expand Down Expand Up @@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
flax_karras_diffusion_scheduler_compatibles,
get_velocity_common,
)

Expand Down Expand Up @@ -84,7 +84,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_compatibles = flax_karras_diffusion_scheduler_compatibles

dtype: jnp.dtype

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -105,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -116,7 +116,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
flax_karras_diffusion_scheduler_compatibles,
)


Expand Down Expand Up @@ -139,7 +139,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_compatibles = flax_karras_diffusion_scheduler_compatibles

dtype: jnp.dtype

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -115,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -74,7 +74,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`"linear"`, `"log_linear"`].
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_heun_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -47,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 2

@register_to_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 2

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -48,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 2

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import SchedulerMixin, karras_diffusion_scheduler_compatibles


@dataclass
Expand Down Expand Up @@ -69,7 +69,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
flax_karras_diffusion_scheduler_compatibles,
)


Expand Down Expand Up @@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_compatibles = flax_karras_diffusion_scheduler_compatibles

dtype: jnp.dtype

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput, karras_diffusion_scheduler_compatibles


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -91,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_compatibles = karras_diffusion_scheduler_compatibles
order = 1

@register_to_config
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
flax_karras_diffusion_scheduler_compatibles,
)


Expand Down Expand Up @@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_compatibles = flax_karras_diffusion_scheduler_compatibles

dtype: jnp.dtype
pndm_order: int
Expand Down
Loading