From 2b24dba599fa2e9f306e6fc77a67f1a4a02a88f7 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 14:58:08 +0200
Subject: [PATCH 01/13] Don't use `load_state_dict` if torch is not installed.

---
 src/diffusers/modeling_flax_utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 80c3fb68a49a..1a3fea1e8a48 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -28,7 +28,6 @@
 from requests import HTTPError

 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .modeling_utils import load_state_dict
 from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
@@ -37,6 +36,7 @@
     WEIGHTS_NAME,
     logging,
 )
+from . import is_torch_available


 logger = logging.get_logger(__name__)

@@ -391,6 +391,14 @@ def from_pretrained(
         )

         if from_pt:
+            if is_torch_available():
+                from .modeling_utils import load_state_dict
+            else:
+                raise EnvironmentError(
+                    f"Can't load the model in PyTorch format because PyTorch is not installed. "
+                    f"Please, install PyTorch or use native Flax weights."
+                )
+
             # Step 1: Get the pytorch file
             pytorch_model_file = load_state_dict(model_file)


From f653140134b74d9ffec46d970eb46925fe3a409d Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 15:00:50 +0200
Subject: [PATCH 02/13] Define `SchedulerOutput` to use torch or flax arrays.

---
 src/diffusers/schedulers/scheduling_utils.py | 41 ++++++++++++++------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 1cc1d94414a6..8e53c3239a03 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -14,26 +14,43 @@
 import warnings
 from dataclasses import dataclass

-import torch
-
 from ..utils import BaseOutput
+from ..utils import is_torch_available, is_flax_available


 SCHEDULER_CONFIG_NAME = "scheduler_config.json"


-@dataclass
-class SchedulerOutput(BaseOutput):
-    """
-    Base class for the scheduler's step function output.
+if is_torch_available():
+    import torch

-    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
-            denoising loop.
-    """
+    @dataclass
+    class SchedulerOutput(BaseOutput):
+        """
+        Base class for the scheduler's step function output.
+
+        Args:
+            prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+                Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+                denoising loop.
+        """
+
+        prev_sample: torch.FloatTensor
+
+if is_flax_available():
+    import jax.numpy as jnp
+
+    class SchedulerOutput(BaseOutput):
+        """
+        Base class for the scheduler's step function output.
+
+        Args:
+            prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+                Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+                denoising loop.
+        """

-    prev_sample: torch.FloatTensor
+        prev_sample: jnp.ndarray


 class SchedulerMixin:
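Both patches above gate torch-specific code behind `is_torch_available()`, and patch 03 below applies the same guard to scheduler imports. As context, here is a minimal sketch of how such a soft-dependency check can be implemented; the helper name matches the one imported above, but the `importlib`-based body is an illustrative assumption, not necessarily the exact implementation in `diffusers.utils`:

    import importlib.util

    def is_torch_available() -> bool:
        # Sketch only (assumed implementation): report whether "torch" can be
        # imported without actually importing it, so the check stays cheap and
        # side-effect free. The real helper may also cache the result and
        # inspect the installed version.
        return importlib.util.find_spec("torch") is not None
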
From c899dee2a05f5b761ef8def9c8de623d5db95afd Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 15:02:04 +0200
Subject: [PATCH 03/13] Don't import LMSDiscreteScheduler without torch.

---
 src/diffusers/schedulers/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 495f30d9fabd..6189ec4c74fe 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -34,10 +34,12 @@
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+    from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403

-if is_scipy_available():
+
+if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403

From 7883e8e1a11ce76d6775c56c6d25585fa35c8a68 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 15:49:11 +0200
Subject: [PATCH 04/13] Create distinct FlaxSchedulerOutput.

---
 .../schedulers/scheduling_ddim_flax.py        | 14 ++---
 .../schedulers/scheduling_ddpm_flax.py        | 14 ++---
 .../schedulers/scheduling_karras_ve_flax.py   |  8 +--
 .../scheduling_lms_discrete_flax.py           | 14 ++---
 .../schedulers/scheduling_pndm_flax.py        | 26 +++++-----
 .../schedulers/scheduling_sde_ve_flax.py      | 12 ++---
 .../schedulers/scheduling_utils_flax.py       | 52 +++++++++++++++++++
 7 files changed, 96 insertions(+), 44 deletions(-)
 create mode 100644 src/diffusers/schedulers/scheduling_utils_flax.py

diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index d81d66607147..783f827bb483 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -68,11 +68,11 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class DDIMSchedulerOutput(FlaxSchedulerOutput):
     state: DDIMSchedulerState


-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -183,7 +183,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -197,10 +197,10 @@
             key (`random.KeyArray`): a PRNG key.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`DDIMSchedulerOutput`] or `tuple`: [`DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -252,7 +252,7 @@
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return DDIMSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 7c7b8d29ab52..085ea18e5687 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -67,11 +67,11 @@ def create(cls, num_train_timesteps: int):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class DDPMSchedulerOutput(FlaxSchedulerOutput):
     state: DDPMSchedulerState


-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.
@@ -191,7 +191,7 @@ def step(
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -205,10 +205,10 @@
             key (`random.KeyArray`): a PRNG key.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`DDPMSchedulerOutput`] or `tuple`: [`DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -257,7 +257,7 @@
         if not return_dict:
             return (pred_prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

     def add_noise(
         self,
diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
index c320b79e6dcd..549356d329d2 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -22,7 +22,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils_flax import FlaxSchedulerMixin


 @flax.struct.dataclass
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
     state: KarrasVeSchedulerState


-class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
     the VE column of Table 1 from [1] for reference.
@@ -170,7 +170,7 @@ def step(
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
@@ -209,7 +209,7 @@ def step_correct(
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
             sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
             derivative (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 4784e4fafccb..9a7c938795f8 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin


 @flax.struct.dataclass
@@ -37,11 +37,11 @@ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):


 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class LMSSchedulerOutput(FlaxSchedulerOutput):
     state: LMSDiscreteSchedulerState


-class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
     Katherine Crowson:
@@ -145,7 +145,7 @@ def step(
         sample: jnp.ndarray,
         order: int = 4,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[LMSSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -157,10 +157,10 @@
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than LMSSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`LMSSchedulerOutput`] or `tuple`: [`LMSSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.
""" @@ -187,7 +187,7 @@ def step( if not return_dict: return (prev_sample, state) - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + return LMSSchedulerOutput(prev_sample=prev_sample, state=state) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 9e2b19f01301..de3bf0a07bfe 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -76,11 +76,11 @@ def create(cls, num_train_timesteps: int): @dataclass -class FlaxSchedulerOutput(SchedulerOutput): +class PNDMSchedulerOutput(FlaxSchedulerOutput): state: PNDMSchedulerState -class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): +class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. @@ -211,7 +211,7 @@ def step( timestep: int, sample: jnp.ndarray, return_dict: bool = True, - ) -> Union[FlaxSchedulerOutput, Tuple]: + ) -> Union[PNDMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -224,10 +224,10 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -249,7 +249,7 @@ def step( if not return_dict: return (prev_sample, state) - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + return PNDMSchedulerOutput(prev_sample=prev_sample, state=state) def step_prk( self, @@ -257,7 +257,7 @@ def step_prk( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - ) -> Union[FlaxSchedulerOutput, Tuple]: + ) -> Union[PNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. @@ -268,10 +268,10 @@ def step_prk( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" @@ -327,7 +327,7 @@ def step_plms( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - ) -> Union[FlaxSchedulerOutput, Tuple]: + ) -> Union[PNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. @@ -338,10 +338,10 @@ def step_plms( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 08fbe14732da..8eb3b0d63929 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -22,7 +22,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput @flax.struct.dataclass @@ -38,7 +38,7 @@ def create(cls): @dataclass -class FlaxSdeVeOutput(SchedulerOutput): +class FlaxSdeVeOutput(FlaxSchedulerOutput): """ Output class for the ScoreSdeVeScheduler's step function output. @@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput): prev_sample_mean: Optional[jnp.ndarray] = None -class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. @@ -168,7 +168,7 @@ def step_pred( sample (`jnp.ndarray`): current instance of sample being created by diffusion process. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class Returns: [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When @@ -216,7 +216,7 @@ def step_correct( sample: jnp.ndarray, key: random.KeyArray, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[FlaxSdeVeOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. @@ -227,7 +227,7 @@ def step_correct( sample (`jnp.ndarray`): current instance of sample being created by diffusion process. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class Returns: [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. 
When diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py new file mode 100644 index 000000000000..5b52bc7c4284 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -0,0 +1,52 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +import jax.numpy as jnp +from dataclasses import dataclass + +from ..utils import BaseOutput + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + + +class FlaxSchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: jnp.ndarray + + +class FlaxSchedulerMixin: + """ + Mixin containing common functions for the schedulers. + """ + + config_name = SCHEDULER_CONFIG_NAME + + def set_format(self, tensor_format="pt"): + warnings.warn( + "The method `set_format` is deprecated and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this function as the schedulers" + "are always in Pytorch", + DeprecationWarning, + ) + return self From 5f84b67d4837552682a1aaeb1ff2bf688c9e686b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:01:18 +0000 Subject: [PATCH 05/13] Additional changes required for FlaxSchedulerMixin --- src/diffusers/__init__.py | 1 + src/diffusers/pipeline_flax_utils.py | 6 +++--- src/diffusers/schedulers/__init__.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index acdddaac4d26..bb9565bf178e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -74,6 +74,7 @@ FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxScoreSdeVeScheduler, + FlaxSchedulerMixin, ) else: from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 9ea94ee0f2e1..91a0aef18ef8 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -30,7 +30,7 @@ from .configuration_utils import ConfigMixin from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging @@ -46,7 +46,7 @@ LOADABLE_CLASSES = { "diffusers": { "FlaxModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], + "FlaxSchedulerMixin": ["save_config", "from_config"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], }, "transformers": { @@ -436,7 +436,7 @@ def from_pretrained(cls, 
pretrained_model_name_or_path: Optional[Union[str, os.P else: loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params - elif issubclass(class_obj, SchedulerMixin): + elif issubclass(class_obj, FlaxSchedulerMixin): loaded_sub_model, scheduler_state = load_method(loadable_folder) params[name] = scheduler_state else: diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6189ec4c74fe..5658dbde46df 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -35,6 +35,7 @@ from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler + from .scheduling_utils_flax import FlaxSchedulerMixin else: from ..utils.dummy_flax_objects import * # noqa F403 From 912ca266bf7f64da665e07816717cb773d1ea8b4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:02:01 +0000 Subject: [PATCH 06/13] Do not import torch pipelines in Flax. --- src/diffusers/pipelines/__init__.py | 15 +++++++++------ .../pipelines/stable_diffusion/__init__.py | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8e3c8592a258..8a84c16dbc05 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,11 +1,14 @@ from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available -from .ddim import DDIMPipeline -from .ddpm import DDPMPipeline -from .latent_diffusion_uncond import LDMPipeline -from .pndm import PNDMPipeline -from .score_sde_ve import ScoreSdeVePipeline -from .stochastic_karras_ve import KarrasVePipeline +if is_torch_available(): + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .latent_diffusion_uncond import LDMPipeline + from .pndm import PNDMPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline +else: + from ..utils.dummy_pt_objects import * # noqa F403 if is_torch_available() and is_transformers_available(): from .latent_diffusion import LDMTextToImagePipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index e3b8e2f0f30c..5c1504e5e527 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,7 @@ import PIL from PIL import Image -from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available +from ...utils import BaseOutput, is_flax_available, is_torch_available, is_onnx_available, is_transformers_available @dataclass @@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected: List[bool] -if is_transformers_available(): +if is_transformers_available() and is_torch_available(): from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline From a64ece7e8bc062acb9826f314d3b4225ae3c3188 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:05:59 +0000 Subject: [PATCH 07/13] Revert "Define `SchedulerOutput` to use torch or flax arrays." This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d. 
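Patch 07 below reverts the conditional double definition of `SchedulerOutput` introduced in patch 02. A likely reason: when both backends are installed, the second guarded `class SchedulerOutput` block rebinds the module-level name and silently discards the torch variant. A minimal sketch of that failure mode (standalone Python, with the availability checks replaced by hypothetical booleans):

    # Both guards can be true at once; the later class statement wins.
    torch_installed = True
    flax_installed = True

    if torch_installed:
        class SchedulerOutput:
            backend = "torch"

    if flax_installed:
        class SchedulerOutput:  # noqa: F811 -- rebinds the name defined above
            backend = "flax"

    print(SchedulerOutput.backend)  # prints "flax"; the torch variant is unreachable

The remaining patches instead keep the torch `SchedulerOutput` in `scheduling_utils.py` and give Flax its own `FlaxSchedulerOutput` in `scheduling_utils_flax.py`, so the two definitions can never collide.
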
From a64ece7e8bc062acb9826f314d3b4225ae3c3188 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 16:05:59 +0000
Subject: [PATCH 07/13] Revert "Define `SchedulerOutput` to use torch or flax
 arrays."

This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d.
---
 src/diffusers/schedulers/scheduling_utils.py | 41 ++++++--------------
 1 file changed, 12 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 8e53c3239a03..1cc1d94414a6 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -14,43 +14,26 @@
 import warnings
 from dataclasses import dataclass

+import torch
+
 from ..utils import BaseOutput
-from ..utils import is_torch_available, is_flax_available


 SCHEDULER_CONFIG_NAME = "scheduler_config.json"


+@dataclass
+class SchedulerOutput(BaseOutput):
+    """
+    Base class for the scheduler's step function output.
-if is_torch_available():
-    import torch

+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
-    @dataclass
-    class SchedulerOutput(BaseOutput):
-        """
-        Base class for the scheduler's step function output.
-
-        Args:
-            prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-                Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
-                denoising loop.
-        """
-
-        prev_sample: torch.FloatTensor
-
-if is_flax_available():
-    import jax.numpy as jnp
-
-    class SchedulerOutput(BaseOutput):
-        """
-        Base class for the scheduler's step function output.
-
-        Args:
-            prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
-                Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
-                denoising loop.
-        """

-        prev_sample: jnp.ndarray
+    prev_sample: torch.FloatTensor


 class SchedulerMixin:

From 4cd5431ed63bfaadcab488990ac8d14a685f4f31 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 16:12:18 +0000
Subject: [PATCH 08/13] Prefix Flax scheduler outputs for consistency.

---
 .../schedulers/scheduling_ddim_flax.py        | 10 ++++-----
 .../schedulers/scheduling_ddpm_flax.py        | 10 ++++-----
 .../scheduling_lms_discrete_flax.py           | 10 ++++-----
 .../schedulers/scheduling_pndm_flax.py        | 22 +++++++++----------
 4 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 783f827bb483..f6b886e2e59c 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -68,7 +68,7 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):


 @dataclass
-class DDIMSchedulerOutput(FlaxSchedulerOutput):
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
     state: DDIMSchedulerState


@@ -183,7 +183,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[DDIMSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -197,10 +197,10 @@
             key (`random.KeyArray`): a PRNG key.
             eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): TODO
-            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
-            [`DDIMSchedulerOutput`] or `tuple`: [`DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -252,7 +252,7 @@
         if not return_dict:
             return (prev_sample, state)

-        return DDIMSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 085ea18e5687..aa562535f0fa 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -67,7 +67,7 @@ def create(cls, num_train_timesteps: int):


 @dataclass
-class DDPMSchedulerOutput(FlaxSchedulerOutput):
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
     state: DDPMSchedulerState


@@ -191,7 +191,7 @@ def step(
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
-    ) -> Union[DDPMSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -205,10 +205,10 @@
             key (`random.KeyArray`): a PRNG key.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
-            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

         Returns:
-            [`DDPMSchedulerOutput`] or `tuple`: [`DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -257,7 +257,7 @@
         if not return_dict:
             return (pred_prev_sample, state)

-        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

     def add_noise(
         self,
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 9a7c938795f8..37d04f8b8364 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -37,7 +37,7 @@


 @dataclass
-class LMSSchedulerOutput(FlaxSchedulerOutput):
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
     state: LMSDiscreteSchedulerState


@@ -145,7 +145,7 @@ def step(
         sample: jnp.ndarray,
         order: int = 4,
         return_dict: bool = True,
-    ) -> Union[LMSSchedulerOutput, Tuple]:
+    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -157,10 +157,10 @@
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
            order: coefficient for multi-step inference.
-            return_dict (`bool`): option for returning tuple rather than LMSSchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

         Returns:
-            [`LMSSchedulerOutput`] or `tuple`: [`LMSSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -187,7 +187,7 @@
         if not return_dict:
             return (prev_sample, state)

-        return LMSSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index de3bf0a07bfe..75688d731a7d 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -76,7 +76,7 @@ def create(cls, num_train_timesteps: int):


 @dataclass
-class PNDMSchedulerOutput(FlaxSchedulerOutput):
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
     state: PNDMSchedulerState


@@ -211,7 +211,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[PNDMSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -224,10 +224,10 @@
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

         Returns:
-            [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.

         """
@@ -249,7 +249,7 @@
         if not return_dict:
             return (prev_sample, state)

-        return PNDMSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)

     def step_prk(
         self,
@@ -257,7 +257,7 @@
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-    ) -> Union[PNDMSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.
@@ -268,10 +268,10 @@
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

         Returns:
-            [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is the sample tensor.
""" @@ -327,7 +327,7 @@ def step_plms( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - ) -> Union[PNDMSchedulerOutput, Tuple]: + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. @@ -338,10 +338,10 @@ def step_plms( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than PNDMSchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`PNDMSchedulerOutput`] or `tuple`: [`PNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ From ba25c1b5c728d4d3ae16d14e8d1af97f6f9825d2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:15:57 +0000 Subject: [PATCH 09/13] make style --- src/diffusers/__init__.py | 2 +- src/diffusers/modeling_flax_utils.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/stable_diffusion/__init__.py | 2 +- src/diffusers/schedulers/__init__.py | 1 - src/diffusers/schedulers/scheduling_ddim_flax.py | 2 +- src/diffusers/schedulers/scheduling_ddpm_flax.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete_flax.py | 2 +- src/diffusers/schedulers/scheduling_utils_flax.py | 4 ++-- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bb9565bf178e..1cf64a4a2ebf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -73,8 +73,8 @@ FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, - FlaxScoreSdeVeScheduler, FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, ) else: from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 1a3fea1e8a48..0cd58d602523 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -27,6 +27,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError +from . import is_torch_available from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax from .utils import ( CONFIG_NAME, @@ -36,7 +37,6 @@ WEIGHTS_NAME, logging, ) -from . 
import is_torch_available logger = logging.get_logger(__name__) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8a84c16dbc05..1c31595fb0cf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,5 +1,6 @@ from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available + if is_torch_available(): from .ddim import DDIMPipeline from .ddpm import DDPMPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5c1504e5e527..615fa404da0b 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,7 @@ import PIL from PIL import Image -from ...utils import BaseOutput, is_flax_available, is_torch_available, is_onnx_available, is_transformers_available +from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available @dataclass diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5658dbde46df..a906c39eb24c 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -34,7 +34,6 @@ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler - from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_utils_flax import FlaxSchedulerMixin else: from ..utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index f6b886e2e59c..1c297abef97b 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index aa562535f0fa..9ac57250a7cd 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -23,7 +23,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 37d04f8b8364..b18a553643ff 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerOutput, FlaxSchedulerMixin +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput @flax.struct.dataclass diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py 
b/src/diffusers/schedulers/scheduling_utils_flax.py index 5b52bc7c4284..00844b5aeb47 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -import jax.numpy as jnp from dataclasses import dataclass +import jax.numpy as jnp + from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" - class FlaxSchedulerOutput(BaseOutput): """ Base class for the scheduler's step function output. From d188843549ac992c5090c5864077256cf4fc6a87 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:22:09 +0000 Subject: [PATCH 10/13] FlaxSchedulerOutput is now a dataclass. --- src/diffusers/schedulers/scheduling_utils_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 00844b5aeb47..ee14d87d78aa 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -21,7 +21,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json" - +@dataclass class FlaxSchedulerOutput(BaseOutput): """ Base class for the scheduler's step function output. From 2af55a2d955447ad1a566d36dc105f997e70b226 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:22:24 +0000 Subject: [PATCH 11/13] Don't use f-string without placeholders. --- src/diffusers/modeling_flax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 0cd58d602523..9e40edf263fd 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -395,8 +395,8 @@ def from_pretrained( from .modeling_utils import load_state_dict else: raise EnvironmentError( - f"Can't load the model in PyTorch format because PyTorch is not installed. " - f"Please, install PyTorch or use native Flax weights." + "Can't load the model in PyTorch format because PyTorch is not installed. " + "Please, install PyTorch or use native Flax weights." ) # Step 1: Get the pytorch file From cbabee303d16e8a211245294a0b2dc5a14cc2b06 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 30 Sep 2022 16:24:29 +0000 Subject: [PATCH 12/13] Add blank line. 
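Of the two small fixes above, the `@dataclass` restoration in patch 10 is the more consequential: without the decorator, `prev_sample: jnp.ndarray` is a bare class-level annotation and never becomes a field. A minimal illustration using the standard library (an analogy only; the actual field handling lives in diffusers' `BaseOutput`, which is not shown in these patches):

    from dataclasses import dataclass, fields

    @dataclass
    class Decorated:
        prev_sample: int

    class Undecorated:
        prev_sample: int

    print([f.name for f in fields(Decorated)])  # ['prev_sample']
    # fields(Undecorated) raises TypeError: it is not a dataclass, so the
    # annotation above never became a real field.
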
From cbabee303d16e8a211245294a0b2dc5a14cc2b06 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 30 Sep 2022 16:24:29 +0000
Subject: [PATCH 12/13] Add blank line.

---
 src/diffusers/schedulers/scheduling_utils_flax.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index ee14d87d78aa..9e1024b7235c 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -21,6 +21,7 @@
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"

+
 @dataclass
 class FlaxSchedulerOutput(BaseOutput):
     """

From ba4868ce41a52ff5a322f2577a9f4a2b53d8c3ce Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 3 Oct 2022 15:09:02 +0200
Subject: [PATCH 13/13] Style (docstrings)

---
 src/diffusers/schedulers/scheduling_ddim_flax.py          |  4 ++--
 src/diffusers/schedulers/scheduling_ddpm_flax.py          |  4 ++--
 .../schedulers/scheduling_lms_discrete_flax.py            |  4 ++--
 src/diffusers/schedulers/scheduling_pndm_flax.py          | 12 ++++++------
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 1c297abef97b..5f81da698719 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -200,8 +200,8 @@ def step(
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
-            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if state.num_inference_steps is None:
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 9ac57250a7cd..86679bd059a5 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -208,8 +208,8 @@ def step(
             return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

         Returns:
-            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         t = timestep
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index b18a553643ff..f08c9899c918 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -160,8 +160,8 @@ def step(
             return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

         Returns:
-            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
""" sigma = state.sigmas[timestep] diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 75688d731a7d..516051852b80 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -227,8 +227,8 @@ def step( return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.config.skip_prk_steps: @@ -271,8 +271,8 @@ def step_prk( return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: @@ -341,8 +341,8 @@ def step_plms( return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: