diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index dd3c2ac85d2c..015b79b2780d 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -96,7 +96,13 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ @register_to_config @@ -109,6 +115,7 @@ def __init__( trained_betas: Optional[jnp.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -144,9 +151,7 @@ def _get_variance(self, timestep, prev_timestep): return variance - def set_timesteps( - self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0 - ) -> DDIMSchedulerState: + def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -155,9 +160,9 @@ def set_timesteps( the `FlaxDDIMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ + offset = self.config.steps_offset + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 @@ -263,9 +268,14 @@ def add_noise( timesteps: jnp.ndarray, ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod[:, None] + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None] noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index f686a2a32234..9096663016c2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -266,9 +266,14 @@ def add_noise( timesteps: jnp.ndarray, ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 1431bdacf54c..7f4c076b54d1 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -198,8 +198,11 @@ def add_noise( noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: - sigmas = self.match_shape(state.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas + sigma = state.sigmas[timesteps].flatten() + while len(sigma.shape) < len(noise.shape): + sigma = sigma[..., None] + + noisy_samples = original_samples + noise * sigma return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 8444d6680401..efc3858ca75a 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -59,7 +59,6 @@ class PNDMSchedulerState: # setable values _timesteps: jnp.ndarray num_inference_steps: Optional[int] = None - _offset: int = 0 prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None timesteps: Optional[jnp.ndarray] = None @@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ @register_to_config @@ -115,6 +122,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -132,6 +141,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. @@ -139,9 +150,7 @@ def __init__( self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) - def set_timesteps( - self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 - ) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -150,16 +159,15 @@ def set_timesteps( the `FlaxPNDMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ + offset = self.config.steps_offset + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] - _timesteps = _timesteps + offset + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset - state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) + state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -254,7 +262,7 @@ def step_prk( ) diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1]) + prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] if state.counter % 4 == 0: @@ -274,7 +282,7 @@ def step_prk( # cur_sample should not be `None` cur_sample = state.cur_sample if state.cur_sample is not None else sample - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state) + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) if not return_dict: @@ -320,7 +328,7 @@ def step_plms( "for more information." ) - prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps if state.counter != 1: state = state.replace(ets=state.ets.append(model_output)) @@ -344,7 +352,7 @@ def step_plms( 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] ) - prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state) + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) if not return_dict: @@ -352,7 +360,7 @@ def step_plms( return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state): + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -365,8 +373,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state) # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -395,9 +403,14 @@ def add_noise( timesteps: jnp.ndarray, ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index e5860706aa2e..08fbe14732da 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -192,14 +192,17 @@ def step_pred( # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods - drift = drift - diffusion[:, None, None, None] ** 2 * model_output + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion[:, None] + drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of key = random.split(key, num=1) noise = random.normal(key=key, shape=sample.shape) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? - prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean, state) @@ -248,8 +251,11 @@ def step_correct( step_size = step_size * jnp.ones(sample.shape[0]) # compute corrected sample: model_output term and noise term - prev_sample_mean = sample + step_size[:, None, None, None] * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size[:, None] + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise if not return_dict: return (prev_sample, state)