generative/networks/schedulers/ddim.py (21 changes: 13 additions & 8 deletions)
@@ -65,6 +65,7 @@ class DDIMScheduler(Scheduler):
set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
For the final step there is no previous alpha. When this option is `True` the previous alpha product is
fixed to `1`, otherwise it uses the value of alpha at step 0.
+ A similar approach is used for reverse steps: setting this option to `True` will use zero as the first alpha.
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
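
Note (not part of the diff): the stable-diffusion-style combination the docstring describes would be configured like this. A minimal sketch, assuming the import path matches the file location and that the scheduler exposes a diffusers-style `set_timesteps`:

    from generative.networks.schedulers import DDIMScheduler

    # `steps_offset=1` with `set_alpha_to_one=False`: the last sampling step reuses
    # the alpha product of step 0 as its "previous" alpha, as in stable diffusion.
    scheduler = DDIMScheduler(num_train_timesteps=1000, set_alpha_to_one=False, steps_offset=1)
    scheduler.set_timesteps(num_inference_steps=50)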
@@ -96,6 +97,10 @@ def __init__(
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+ # For reverse steps, we require the next alphas_cumprod. Similarly to above, the step at the final
+ # timestep doesn't have a next value, so we can either set it to zero or use the last value.
+ self.first_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_one else self.alphas_cumprod[-1]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

@@ -234,7 +239,7 @@ def reversed_step(
sample: current instance of sample being created by diffusion process.

Returns:
- pred_prev_sample: Predicted previous sample
+ pred_next_sample: Predicted next sample
pred_original_sample: Predicted original sample
"""
# See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
@@ -245,14 +250,14 @@ def reversed_step(
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
- # - pred_post_sample -> "x_t+1"
+ # - pred_next_sample -> "x_t+1"

- # 1. get previous step value (=t+1)
- prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+ # 1. get next step value (=t+1)
+ next_timestep = timestep + self.num_train_timesteps // self.num_inference_steps

# 2. compute alphas, betas at timestep t+1
alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod

beta_prod_t = 1 - alpha_prod_t

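Concretely (illustration, not part of the diff): with 1000 training timesteps and 50 inference steps the stride is 20, so the out-of-range fallback triggers exactly once, at the highest inference timestep:

    num_train_timesteps, num_inference_steps = 1000, 50
    stride = num_train_timesteps // num_inference_steps  # 20
    timestep = 980                                       # highest inference timestep
    next_timestep = timestep + stride                    # 1000 == len(alphas_cumprod)
    # next_timestep is out of range, so alpha_prod_t_next = self.first_alpha_cumprod
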
@@ -274,9 +279,9 @@ def reversed_step(
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+ pred_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon

# 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+ pred_next_sample = alpha_prod_t_next ** (0.5) * pred_original_sample + pred_sample_direction

- return pred_post_sample, pred_original_sample
+ return pred_next_sample, pred_original_sample
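
Putting the hunks together: in the notation mapped in the comments above, the deterministic (eta = 0) update that `reversed_step` now computes is formula (12) of the DDIM paper stepped forward in time,

    x_{t+1} = \sqrt{\bar{\alpha}_{t+1}} \, \hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t+1}} \, \epsilon_\theta(x_t),
    \qquad
    \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \, \epsilon_\theta(x_t)}{\sqrt{\bar{\alpha}_t}}

A hedged sketch of the DDIM inversion loop this enables (not part of this PR; `model`, `image`, an imported `torch`, and the ascending-timestep iteration are assumptions):

    sample = image                            # e.g. a batch of images scaled to [-1, 1]
    for t in scheduler.timesteps.flip(0):     # walk timesteps in increasing order
        with torch.no_grad():
            model_output = model(sample, t)   # predicted noise epsilon_theta(x_t)
        sample, _ = scheduler.reversed_step(model_output, t, sample)
    # `sample` now approximates the latent noise x_T encoding `image`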