diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index acdddaac4d26..1cf64a4a2ebf 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -73,6 +73,7 @@
         FlaxKarrasVeScheduler,
         FlaxLMSDiscreteScheduler,
         FlaxPNDMScheduler,
+        FlaxSchedulerMixin,
         FlaxScoreSdeVeScheduler,
     )
 else:
diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 80c3fb68a49a..9e40edf263fd 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -27,8 +27,8 @@
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
+from . import is_torch_available
 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .modeling_utils import load_state_dict
 from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
@@ -391,6 +391,14 @@ def from_pretrained(
             )
 
         if from_pt:
+            if is_torch_available():
+                from .modeling_utils import load_state_dict
+            else:
+                raise EnvironmentError(
+                    "Can't load the model in PyTorch format because PyTorch is not installed. "
+                    "Please install PyTorch or use native Flax weights."
+                )
+
             # Step 1: Get the pytorch file
             pytorch_model_file = load_state_dict(model_file)
 
diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py
index 9ea94ee0f2e1..91a0aef18ef8 100644
--- a/src/diffusers/pipeline_flax_utils.py
+++ b/src/diffusers/pipeline_flax_utils.py
@@ -30,7 +30,7 @@
 
 from .configuration_utils import ConfigMixin
 from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
 
 
@@ -46,7 +46,7 @@
 LOADABLE_CLASSES = {
     "diffusers": {
         "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
-        "SchedulerMixin": ["save_config", "from_config"],
+        "FlaxSchedulerMixin": ["save_config", "from_config"],
         "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
     },
     "transformers": {
@@ -436,7 +436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 else:
                     loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                     params[name] = loaded_params
-            elif issubclass(class_obj, SchedulerMixin):
+            elif issubclass(class_obj, FlaxSchedulerMixin):
                 loaded_sub_model, scheduler_state = load_method(loadable_folder)
                 params[name] = scheduler_state
             else:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8e3c8592a258..1c31595fb0cf 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,12 +1,16 @@
 from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
 
-from .ddim import DDIMPipeline
-from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LDMPipeline
-from .pndm import PNDMPipeline
-from .score_sde_ve import ScoreSdeVePipeline
-from .stochastic_karras_ve import KarrasVePipeline
+if is_torch_available():
+    from .ddim import DDIMPipeline
+    from .ddpm import DDPMPipeline
+    from .latent_diffusion_uncond import LDMPipeline
+    from .pndm import PNDMPipeline
+    from .score_sde_ve import ScoreSdeVePipeline
+    from .stochastic_karras_ve import KarrasVePipeline
+else:
+    from ..utils.dummy_pt_objects import *  # noqa F403
+
 
 if is_torch_available() and is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
     from .stable_diffusion import (
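The conditional imports above all hinge on the availability probes in diffusers' utils. As reviewer context, here is a minimal sketch of how such a probe can be written; the real is_torch_available() lives in src/diffusers/utils/import_utils.py and is more involved (it caches the result and checks installed versions), so this is illustrative only.

# Minimal sketch of an availability probe in the spirit of is_torch_available();
# not the actual diffusers implementation.
import importlib.util


def is_torch_available() -> bool:
    # find_spec returns None when the package is not importable
    return importlib.util.find_spec("torch") is not None


def is_transformers_available() -> bool:
    return importlib.util.find_spec("transformers") is not None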
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index e3b8e2f0f30c..615fa404da0b 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@
 import PIL
 from PIL import Image
 
-from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
 
 
 @dataclass
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
     nsfw_content_detected: List[bool]
 
 
-if is_transformers_available():
+if is_transformers_available() and is_torch_available():
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 495f30d9fabd..a906c39eb24c 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -34,10 +34,12 @@
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+    from .scheduling_utils_flax import FlaxSchedulerMixin
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403
 
-if is_scipy_available():
+
+if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index d81d66607147..5f81da698719 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -68,11 +68,11 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
 
 
 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
     state: DDIMSchedulerState
 
 
-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -183,7 +183,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -197,11 +197,11 @@ def step(
             key (`random.KeyArray`): a PRNG key.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
 
         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
         if state.num_inference_steps is None:
@@ -252,7 +252,7 @@
         if not return_dict:
             return (prev_sample, state)
 
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
 
     def add_noise(
         self,
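The renamed FlaxDDIMSchedulerOutput keeps the functional contract visible in the hunks above: step() consumes an immutable state and returns it alongside the new sample. Below is a sketch of a denoising loop written against only that contract; apply_model is a hypothetical stand-in for a UNet forward pass, not an API introduced by this diff.

import jax.numpy as jnp


def ddim_denoise(scheduler, state, apply_model, sample: jnp.ndarray, timesteps):
    # Thread the immutable scheduler state through the loop explicitly,
    # as the functional Flax API in this diff requires.
    for t in timesteps:
        model_output = apply_model(sample, t)
        out = scheduler.step(state, model_output, t, sample)  # FlaxDDIMSchedulerOutput
        sample, state = out.prev_sample, out.state
    return sample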
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 7c7b8d29ab52..86679bd059a5 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
 from jax import random
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -67,11 +67,11 @@ def create(cls, num_train_timesteps: int):
 
 
 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
     state: DDPMSchedulerState
 
 
-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.
@@ -191,7 +191,7 @@ def step(
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -205,11 +205,11 @@ def step(
             key (`random.KeyArray`): a PRNG key.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
 
         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
         t = timestep
@@ -257,7 +257,7 @@ def step(
         if not return_dict:
             return (pred_prev_sample, state)
 
-        return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
 
     def add_noise(
         self,
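FlaxDDPMScheduler.step() additionally takes a PRNG key, since ancestral sampling draws fresh noise at every step. The following is a hedged sketch of the usual JAX key-splitting pattern around it; scheduler, state and apply_model are placeholders for illustration only.

from jax import random


def ddpm_denoise(scheduler, state, apply_model, sample, timesteps, seed=0):
    key = random.PRNGKey(seed)
    for t in timesteps:
        # One fresh subkey per step so the injected noise is independent
        key, step_key = random.split(key)
        model_output = apply_model(sample, t)
        out = scheduler.step(state, model_output, t, sample, key=step_key)
        sample, state = out.prev_sample, out.state
    return sample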
""" t = timestep @@ -257,7 +257,7 @@ def step( if not return_dict: return (pred_prev_sample, state) - return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state) + return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index c320b79e6dcd..549356d329d2 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils_flax import FlaxSchedulerMixin @flax.struct.dataclass @@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput): state: KarrasVeSchedulerState -class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): +class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. @@ -170,7 +170,7 @@ def step( sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class Returns: [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion @@ -209,7 +209,7 @@ def step_correct( sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class Returns: prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 4784e4fafccb..f08c9899c918 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput @flax.struct.dataclass @@ -37,11 +37,11 @@ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray): @dataclass -class FlaxSchedulerOutput(SchedulerOutput): +class FlaxLMSSchedulerOutput(FlaxSchedulerOutput): state: LMSDiscreteSchedulerState -class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -145,7 +145,7 @@ def step( sample: jnp.ndarray, order: int = 4, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[FlaxLMSSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). 
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 9e2b19f01301..516051852b80 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -76,11 +76,11 @@ def create(cls, num_train_timesteps: int):
 
 
 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
     state: PNDMSchedulerState
 
 
-class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
     namely Runge-Kutta method and a linear multi-step method.
@@ -211,7 +211,7 @@ def step(
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -224,11 +224,11 @@ def step(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`): current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
 
         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
 
         if self.config.skip_prk_steps:
@@ -249,7 +249,7 @@
         if not return_dict:
             return (prev_sample, state)
 
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
 
     def step_prk(
         self,
@@ -257,7 +257,7 @@
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.
@@ -268,11 +268,11 @@ def step_prk(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`): current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
 
         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
 
         if state.num_inference_steps is None:
@@ -327,7 +327,7 @@ def step_plms(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.
@@ -338,11 +338,11 @@ def step_plms(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`): current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
 
         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
 
         if state.num_inference_steps is None:
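PNDM is the only scheduler here whose step() fans out into two integrators: a Runge-Kutta warm-up (step_prk) and a linear multi-step phase (step_plms). The sketch below shows only the dispatch implied by the `if self.config.skip_prk_steps:` context line above; the real scheduler also consults counters kept in PNDMSchedulerState before choosing, which is omitted here.

def step_dispatch_sketch(self, state, model_output, timestep, sample):
    # With skip_prk_steps set, the scheduler behaves like PLMS from step one.
    if self.config.skip_prk_steps:
        return self.step_plms(state, model_output, timestep, sample)
    # Otherwise the Runge-Kutta warm-up precedes the multi-step phase
    # (the per-step choice depends on state counters in the real code).
    return self.step_prk(state, model_output, timestep, sample)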
""" if self.config.skip_prk_steps: @@ -249,7 +249,7 @@ def step( if not return_dict: return (prev_sample, state) - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) def step_prk( self, @@ -257,7 +257,7 @@ def step_prk( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - ) -> Union[FlaxSchedulerOutput, Tuple]: + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. @@ -268,11 +268,11 @@ def step_prk( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: @@ -327,7 +327,7 @@ def step_plms( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - ) -> Union[FlaxSchedulerOutput, Tuple]: + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. @@ -338,11 +338,11 @@ def step_plms( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 08fbe14732da..8eb3b0d63929 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -22,7 +22,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput @flax.struct.dataclass @@ -38,7 +38,7 @@ def create(cls): @dataclass -class FlaxSdeVeOutput(SchedulerOutput): +class FlaxSdeVeOutput(FlaxSchedulerOutput): """ Output class for the ScoreSdeVeScheduler's step function output. @@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput): prev_sample_mean: Optional[jnp.ndarray] = None -class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. 
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
new file mode 100644
index 000000000000..9e1024b7235c
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from dataclasses import dataclass
+
+import jax.numpy as jnp
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class FlaxSchedulerOutput(BaseOutput):
+    """
+    Base class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: jnp.ndarray
+
+
+class FlaxSchedulerMixin:
+    """
+    Mixin containing common functions for the schedulers.
+    """
+
+    config_name = SCHEDULER_CONFIG_NAME
+
+    def set_format(self, tensor_format="pt"):
+        warnings.warn(
+            "The method `set_format` is deprecated and will be removed in version `0.5.0`. "
+            "If you're running your code in PyTorch, you can safely remove this function as the schedulers "
+            "are always in PyTorch.",
+            DeprecationWarning,
+        )
+        return self
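Net effect of the split for downstream code: Flax users can now reach the scheduler base without PyTorch installed. A before/after sketch of the import migration this diff implies, using only names introduced in the hunks above.

# before this PR, the only scheduler mixin lived in a torch-importing module:
#   from diffusers.schedulers.scheduling_utils import SchedulerMixin

# after, the Flax side has a torch-free home, also re-exported at the top
# level (see the src/diffusers/__init__.py hunk at the start of this diff):
from diffusers import FlaxSchedulerMixin
from diffusers.schedulers.scheduling_utils_flax import FlaxSchedulerOutput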