diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index d1395113..3f4ac0c6 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -192,12 +192,13 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output elif self.prediction_type == "sample": pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample # 4. Clip "predicted x_0" if self.clip_sample: @@ -209,7 +210,7 @@ def step( std_dev_t = eta * variance ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction