diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py
index 7c3de648..eb26b2fd 100644
--- a/generative/networks/schedulers/ddim.py
+++ b/generative/networks/schedulers/ddim.py
@@ -65,6 +65,7 @@ class DDIMScheduler(Scheduler):
         set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
             For the final step there is no previous alpha. When this option is `True` the previous alpha product is
             fixed to `1`, otherwise it uses the value of alpha at step 0.
+            A similar approach is used for reversed steps; when `True`, zero is used as the alpha product after the last step.
         steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
@@ -96,6 +97,10 @@ def __init__(
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # For reversed steps we require the next alphas_cumprod. Similarly to above, the last step has no next
+        # value, so we either fix it to zero or use the final value of alphas_cumprod.
+        self.first_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_one else self.alphas_cumprod[-1]
+
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0

@@ -234,7 +239,7 @@ def reversed_step(
             sample: current instance of sample being created by diffusion process.

         Returns:
-            pred_prev_sample: Predicted previous sample
+            pred_next_sample: Predicted next sample
             pred_original_sample: Predicted original sample
         """
         # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
@@ -245,14 +250,14 @@ def reversed_step(
         # - std_dev_t -> sigma_t
         # - eta -> η
         # - pred_sample_direction -> "direction pointing to x_t"
-        # - pred_post_sample -> "x_t+1"
+        # - pred_next_sample -> "x_t+1"

-        # 1. get previous step value (=t+1)
-        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+        # 1. get next step value (=t+1)
+        next_timestep = timestep + self.num_train_timesteps // self.num_inference_steps

         # 2. compute alphas, betas at timestep t+1
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod

         beta_prod_t = 1 - alpha_prod_t

@@ -274,9 +279,9 @@ def reversed_step(
             pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+        pred_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon

         # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        pred_next_sample = alpha_prod_t_next ** (0.5) * pred_original_sample + pred_sample_direction

-        return pred_post_sample, pred_original_sample
+        return pred_next_sample, pred_original_sample
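
For intuition on the boundary this patch guards: with the defaults used elsewhere in the file (1000 training timesteps, 50 inference steps, so a step size of 20), the final reversed step asks for the alpha product at timestep 1000, one past the end of alphas_cumprod; the old `prev_timestep >= 0` check never caught this, so the lookup indexed out of bounds, and it now falls back to the new first_alpha_cumprod instead. A minimal sketch of that arithmetic, assuming DDIMScheduler is importable from the package shown in the file path and that set_timesteps/timesteps behave as in the rest of the file:

    from generative.networks.schedulers import DDIMScheduler

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=50)

    step_size = scheduler.num_train_timesteps // scheduler.num_inference_steps  # 20
    t_last = int(scheduler.timesteps.max())  # 980, the largest timestep the loop visits

    # The "next" timestep of the final reversed step lands one past the end of
    # the schedule, which is exactly where first_alpha_cumprod takes over.
    assert t_last + step_size == len(scheduler.alphas_cumprod)  # 980 + 20 == 1000
    print(scheduler.first_alpha_cumprod)  # tensor(0.) under the default set_alpha_to_one=True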
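
For how reversed_step is typically driven, here is a minimal DDIM-inversion sketch that walks the sampling timesteps in the opposite (ascending) order, mapping x_t to x_t+1 at each call; the `model` below is a hypothetical stand-in for a trained epsilon-prediction network, not part of this file:

    import torch

    from generative.networks.schedulers import DDIMScheduler

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=50)

    sample = torch.randn(1, 1, 64, 64)  # placeholder image batch

    def model(x, t):
        # hypothetical stand-in: a real network would predict the noise epsilon
        return torch.randn_like(x)

    # scheduler.timesteps is ordered for sampling (T -> 0); inversion runs it
    # backwards (0 -> T). The last iteration exercises first_alpha_cumprod.
    for t in scheduler.timesteps.flip(0):
        model_output = model(sample, t)
        sample, _ = scheduler.reversed_step(model_output, t, sample)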