Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions generative/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class DDIMScheduler(nn.Module):
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the
diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""

Expand Down
4 changes: 4 additions & 0 deletions generative/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class DDPMScheduler(nn.Module):
variance_type: {``"fixed_small"``, ``"fixed_large"``, ``"learned"``, ``"learned_range"``}
options to clip the variance used when adding noise to the denoised sample.
clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""

def __init__(
Expand Down
12 changes: 12 additions & 0 deletions generative/networks/schedulers/pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class PNDMScheduler(nn.Module):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
prediction_type: {``"epsilon"``, ``"v_prediction"``}
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
steps_offset:
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
Expand All @@ -69,6 +73,7 @@ def __init__(
beta_schedule: str = "linear",
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
prediction_type: str = "epsilon",
steps_offset: int = 0,
) -> None:
super().__init__()
Expand All @@ -83,6 +88,10 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if prediction_type.lower() not in ["epsilon", "v_prediction"]:
raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon` or `v_prediction`")

self.prediction_type = prediction_type
self.num_train_timesteps = num_train_timesteps
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
Expand Down Expand Up @@ -294,6 +303,9 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

if self.prediction_type == "v_prediction":
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

# corresponds to (α_(t−δ) - α_t) divided by
# denominator of x_t in formula (9) and plus 1
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
Expand Down