diff --git a/generative/networks/schedulers/__init__.py b/generative/networks/schedulers/__init__.py index bb2eb347..29e9020d 100644 --- a/generative/networks/schedulers/__init__.py +++ b/generative/networks/schedulers/__init__.py @@ -14,3 +14,4 @@ from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 7f155dbb..7c3de648 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -33,10 +33,26 @@ import numpy as np import torch -import torch.nn as nn +from monai.utils import StrEnum +from .scheduler import Scheduler -class DDIMScheduler(nn.Module): + +class DDIMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion @@ -44,10 +60,7 @@ class DDIMScheduler(nn.Module): Args: num_train_timesteps: number of diffusion steps used to train the model. - beta_start: the starting `beta` value of inference. - beta_end: the final `beta` value. - beta_schedule: {``"linear"``, ``"scaled_linear"``} - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is @@ -55,44 +68,27 @@ class DDIMScheduler(nn.Module): steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
-        prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of DDIMPredictionType
+        schedule_args: arguments to pass to the schedule function
+
     """

     def __init__(
         self,
         num_train_timesteps: int = 1000,
-        beta_start: float = 1e-4,
-        beta_end: float = 2e-2,
-        beta_schedule: str = "linear",
+        schedule: str = "linear_beta",
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
-        prediction_type: str = "epsilon",
+        prediction_type: str = DDIMPredictionType.EPSILON,
+        **schedule_args,
     ) -> None:
-        super().__init__()
-        self.beta_schedule = beta_schedule
-        if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+        super().__init__(num_train_timesteps, schedule, **schedule_args)

-        if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
-            raise ValueError(
-                f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
-            )
+        if prediction_type not in DDIMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")

         self.prediction_type = prediction_type

-        self.num_train_timesteps = num_train_timesteps
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -103,13 +99,13 @@ def __init__(

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0

-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
+        self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))

         self.clip_sample = clip_sample
         self.steps_offset = steps_offset

         # default the number of inference timesteps to the number of train steps
-        self.set_timesteps(num_train_timesteps)
+        self.set_timesteps(self.num_train_timesteps)

     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
         """
@@ -190,13 +186,13 @@ def step(
         # 3.
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) pred_epsilon = model_output - elif self.prediction_type == "sample": + elif self.prediction_type == DDIMPredictionType.SAMPLE: pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.prediction_type == "v_prediction": + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample @@ -207,19 +203,19 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) + std_dev_t = eta * variance**0.5 # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device = model_output.device if torch.is_tensor(model_output) else "cpu" noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise pred_prev_sample = pred_prev_sample + variance @@ -263,13 +259,13 @@ def reversed_step( # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.prediction_type == "epsilon": + if self.prediction_type == DDIMPredictionType.EPSILON: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output - elif self.prediction_type == "sample": + elif self.prediction_type == DDIMPredictionType.SAMPLE: pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.prediction_type == "v_prediction": + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample @@ -284,50 +280,3 @@ def reversed_step( pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction return pred_post_sample, pred_original_sample - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. - - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 2f25f9f1..e543502c 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -33,10 +33,38 @@ import numpy as np import torch -import torch.nn as nn +from monai.utils import 
StrEnum +from .scheduler import Scheduler -class DDPMScheduler(nn.Module): + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" @@ -44,59 +72,33 @@ class DDPMScheduler(nn.Module): Args: num_train_timesteps: number of diffusion steps used to train the model. - beta_start: the starting `beta` value of inference. - beta_end: the final `beta` value. - beta_schedule: {``"linear"``, ``"scaled_linear"``} - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. - variance_type: {``"fixed_small"``, ``"fixed_large"``, ``"learned"``, ``"learned_range"``} - options to clip the variance used when adding noise to the denoised sample. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. - prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``} - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function """ def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 1e-4, - beta_end: float = 2e-2, - beta_schedule: str = "linear", - variance_type: str = "fixed_small", + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, clip_sample: bool = True, - prediction_type: str = "epsilon", + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args, ) -> None: - super().__init__() - self.beta_schedule = beta_schedule - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + super().__init__(num_train_timesteps, schedule, **schedule_args) - if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]: - raise ValueError( - f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" - ) - - self.prediction_type = prediction_type + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") - self.num_train_timesteps = num_train_timesteps - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") self.clip_sample = clip_sample self.variance_type = variance_type - - # settable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.prediction_type = prediction_type def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ @@ -164,13 +166,13 @@ def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] # hacks - were probably added for training stability - if self.variance_type == "fixed_small": + if self.variance_type == DDPMVarianceType.FIXED_SMALL: variance = torch.clamp(variance, min=1e-20) - elif self.variance_type == "fixed_large": + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: variance = self.betas[timestep] - elif self.variance_type == "learned": + elif self.variance_type == DDPMVarianceType.LEARNED: return predicted_variance - elif self.variance_type == "learned_range": + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE: min_log = variance max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 @@ -207,11 +209,11 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.prediction_type == "epsilon": + if self.prediction_type == DDPMPredictionType.EPSILON: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.prediction_type == "sample": + elif self.prediction_type == DDPMPredictionType.SAMPLE: pred_original_sample = model_output - elif self.prediction_type == "v_prediction": + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" @@ -238,50 +240,3 @@ def step( pred_prev_sample = pred_prev_sample + variance return pred_prev_sample, pred_original_sample - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
- - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 0a1f2018..b729315f 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -35,10 +35,24 @@ import numpy as np import torch -import torch.nn as nn +from monai.utils import StrEnum +from .scheduler import Scheduler -class PNDMScheduler(nn.Module): + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., @@ -46,10 +60,7 @@ class PNDMScheduler(nn.Module): Args: num_train_timesteps: number of diffusion steps used to train the model. - beta_start: the starting `beta` value of inference. - beta_end: the final `beta` value. - beta_schedule: {``"linear"``, ``"scaled_linear"``} - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store skip_prk_steps: allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms step. 
@@ -57,46 +68,30 @@ class PNDMScheduler(nn.Module):
         each diffusion step uses the value of alphas product at that step and at the previous one. For the final
         step there is no previous alpha. When this option is `True` the previous alpha product is
         fixed to `1`, otherwise it uses the value of alpha at step 0.
-        prediction_type: {``"epsilon"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of PNDMPredictionType
         steps_offset: an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        schedule_args: arguments to pass to the schedule function
     """

     def __init__(
         self,
         num_train_timesteps: int = 1000,
-        beta_start: float = 1e-4,
-        beta_end: float = 2e-2,
-        beta_schedule: str = "linear",
+        schedule: str = "linear_beta",
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
-        prediction_type: str = "epsilon",
+        prediction_type: str = PNDMPredictionType.EPSILON,
         steps_offset: int = 0,
+        **schedule_args,
     ) -> None:
-        super().__init__()
-        self.beta_schedule = beta_schedule
-        if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+        super().__init__(num_train_timesteps, schedule, **schedule_args)

-        if prediction_type.lower() not in ["epsilon", "v_prediction"]:
-            raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon` or `v_prediction`")
+        if prediction_type not in PNDMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")

         self.prediction_type = prediction_type

-        self.num_train_timesteps = num_train_timesteps
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
@@ -117,8 +112,6 @@ def __init__(
         self.cur_sample = None
         self.ets = []

-        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-
         # default the number of inference timesteps to the number of train steps
         self.set_timesteps(num_train_timesteps)
@@ -302,7 +295,7 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

-        if self.prediction_type == "v_prediction":
+        if self.prediction_type == PNDMPredictionType.V_PREDICTION:
             model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

         # corresponds to (α_(t−δ) - α_t) divided by
@@ -322,32 +315,3 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i
         )

         return prev_sample
-
-    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        """
-        Add noise to the original samples.
-
-        Args:
-            original_samples: original samples
-            noise: noise to add to samples
-            timesteps: timesteps tensor indicating the timestep to be computed for each sample.
- - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py new file mode 100644 index 00000000..bf153b8b --- /dev/null +++ b/generative/networks/schedulers/scheduler.py @@ -0,0 +1,200 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from generative.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. 
+ + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + from generative.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. 
To see what noise functions are available, print the object NoiseSchedules
+    to get a listing of stored objects with their docstring descriptions.
+
+    Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
+    type, this is now replaced with `schedule`, and most names used with the previous argument now have "_beta"
+    appended to them, e.g. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end`
+    arguments are still used by some schedules, but these are now provided as keyword arguments.
+
+    Args:
+        num_train_timesteps: number of diffusion steps used to train the model.
+        schedule: member of NoiseSchedules,
+            a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple
+        schedule_args: arguments to pass to the schedule function
+    """
+
+    def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None:
+        super().__init__()
+        schedule_args["num_train_timesteps"] = num_train_timesteps
+        noise_sched = NoiseSchedules[schedule](**schedule_args)
+
+        # set betas, alphas, alphas_cumprod based off return value from noise function
+        if isinstance(noise_sched, tuple):
+            self.betas, self.alphas, self.alphas_cumprod = noise_sched
+        else:
+            self.betas = noise_sched
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        self.num_train_timesteps = num_train_timesteps
+        self.one = torch.tensor(1.0)
+
+        # settable values
+        self.num_inference_steps = None
+        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+ + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/generative/utils/__init__.py b/generative/utils/__init__.py index be9d721b..08a1b9b3 100644 --- a/generative/utils/__init__.py +++ b/generative/utils/__init__.py @@ -11,4 +11,6 @@ from __future__ import annotations +from .component_store import ComponentStore from .enums import AdversarialIterationEvents, AdversarialKeys +from .misc import unsqueeze_left, unsqueeze_right diff --git a/generative/utils/component_store.py b/generative/utils/component_store.py new file mode 100644 index 00000000..31ad8460 --- /dev/null +++ b/generative/utils/component_store.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import namedtuple +from keyword import iskeyword +from textwrap import dedent, indent +from typing import Any, Callable, Dict, Iterable, TypeVar + +T = TypeVar("T") + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore: + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + + These objects act as global named places for storing components for objects parameterised by component names. + Typically this is functions although other objects can be added. Printing a component store will produce a + list of members along with their docstring information if present. + + Example: + + .. 
code-block:: python + + TestStore = ComponentStore("Test Store", "A test store for demo purposes") + + @TestStore.add_def("my_func_name", "Some description of your function") + def _my_func(a, b): + '''A description of your function here.''' + return a * b + + print(TestStore) # will print out name, description, and 'my_func_name' with the docstring + + func = TestStore["my_func_name"] + result = func(7, 6) + + """ + + _Component = namedtuple("Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: Dict[str, self._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n* {k}:" + + if hasattr(v.value, "__doc__"): + doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") + result += f"\n{doc}\n" + else: + result += f" {v.description}" + + return result + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/generative/utils/misc.py b/generative/utils/misc.py new file mode 100644 index 00000000..aea74a81 --- /dev/null +++ b/generative/utils/misc.py @@ -0,0 +1,26 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+def unsqueeze_right(arr: T, ndim: int) -> T:
+    """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
+    return arr[(...,) + (None,) * (ndim - arr.ndim)]
+
+
+def unsqueeze_left(arr: T, ndim: int) -> T:
+    """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
+    return arr[(None,) * (ndim - arr.ndim)]
diff --git a/tests/test_component_store.py b/tests/test_component_store.py
new file mode 100644
index 00000000..c6b43bde
--- /dev/null
+++ b/tests/test_component_store.py
@@ -0,0 +1,72 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+from generative.utils import ComponentStore
+
+
+class TestComponentStore(unittest.TestCase):
+    def setUp(self):
+        self.cs = ComponentStore("TestStore", "I am a test store, please ignore")
+
+    def test_empty(self):
+        self.assertEqual(len(self.cs), 0)
+        self.assertEqual(list(self.cs), [])
+
+    def test_add(self):
+        test_obj = object()
+
+        self.assertFalse("test_obj" in self.cs)
+
+        self.cs.add("test_obj", "Test object", test_obj)
+
+        self.assertTrue("test_obj" in self.cs)
+
+        self.assertEqual(len(self.cs), 1)
+        self.assertEqual(list(self.cs), [("test_obj", test_obj)])
+
+        self.assertEqual(self.cs.test_obj, test_obj)
+        self.assertEqual(self.cs["test_obj"], test_obj)
+
+    def test_add2(self):
+        test_obj1 = object()
+        test_obj2 = object()
+
+        self.cs.add("test_obj1", "Test object", test_obj1)
+        self.cs.add("test_obj2", "Test object", test_obj2)
+
+        self.assertEqual(len(self.cs), 2)
+        self.assertTrue("test_obj1" in self.cs)
+        self.assertTrue("test_obj2" in self.cs)
+
+    def test_add_def(self):
+        self.assertFalse("test_func" in self.cs)
+
+        @self.cs.add_def("test_func", "Test function")
+        def test_func():
+            return 123
+
+        self.assertTrue("test_func" in self.cs)
+
+        self.assertEqual(len(self.cs), 1)
+        self.assertEqual(list(self.cs), [("test_func", test_func)])
+
+        self.assertEqual(self.cs.test_func, test_func)
+        self.assertEqual(self.cs["test_func"], test_func)
+
+        # try adding the same function again
+        self.cs.add_def("test_func", "Test function but with new description")(test_func)
+
+        self.assertEqual(len(self.cs), 1)
+        self.assertEqual(self.cs.test_func, test_func)
diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py
index 1f385fd4..85b96991 100644
--- a/tests/test_compute_multiscalessim_metric.py
+++ b/tests/test_compute_multiscalessim_metric.py
@@ -59,18 +59,22 @@ def test3d_gaussian(self):
         expected_value = 0.061796
         self.assertTrue(expected_value - result.item() < 0.000001)

-    def input_ill_input_shape(self):
+    def input_ill_input_shape2d(self):
+        metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5])
+
         with self.assertRaises(ValueError):
-            metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5])
             metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))

+    def
input_ill_input_shape3d(self): + metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) + with self.assertRaises(ValueError): - metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64)) def small_inputs(self): + metric = MultiScaleSSIMMetric(spatial_dims=2) + with self.assertRaises(ValueError): - metric = MultiScaleSSIMMetric(spatial_dims=2) metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index ebda9d31..1769040d 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -331,18 +331,19 @@ def test_with_conditioning_cross_attention_dim_none(self): ) def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + with self.assertRaises(ValueError): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=False, - transformer_num_layers=1, - norm_num_groups=8, - ) with eval_mode(net): net.forward( x=torch.rand((1, 1, 16, 32)), @@ -371,18 +372,19 @@ def test_shape_conditioned_models_class_conditioning(self): self.assertEqual(result.shape, (1, 1, 16, 32)) def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with self.assertRaises(ValueError): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) def test_model_num_channels_not_same_size_of_attention_levels(self): diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 00000000..e0625321 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from generative.utils import unsqueeze_left, unsqueeze_right + +RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] + +LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] + +ALL_CASES = [ + (np.random.rand(3, 4), 2, (3, 4)), + (np.random.rand(3, 4), 0, (3, 4)), + (np.random.rand(3, 4), -1, (3, 4)), + (np.array(3), 4, (1, 1, 1, 1)), + (np.array(3), 0, ()), + (torch.rand(3, 4), 2, (3, 4)), + (torch.rand(3, 4), 0, (3, 4)), + (torch.rand(3, 4), -1, (3, 4)), + (torch.tensor(3), 4, (1, 1, 1, 1)), + (torch.tensor(3), 0, ()), +] + + +class TestUnsqueeze(unittest.TestCase): + @parameterized.expand(RIGHT_CASES + ALL_CASES) + def test_unsqueeze_right(self, arr, ndim, shape): + self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) + + @parameterized.expand(LEFT_CASES + ALL_CASES) + def test_unsqueeze_left(self, arr, ndim, shape): + self.assertEqual(unsqueeze_left(arr, ndim).shape, shape) diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py index 67d773fe..3c64b42c 100644 --- a/tests/test_scheduler_ddim.py +++ b/tests/test_scheduler_ddim.py @@ -19,12 +19,12 @@ from generative.networks.schedulers import DDIMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_2D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_3D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index 7e07563e..835537fe 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -19,17 +19,17 @@ from generative.networks.schedulers import DDPMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: for variance_type in ["fixed_small", "fixed_large"]: TEST_2D_CASE.append( - [{"beta_schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] ) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: for variance_type in ["fixed_small", "fixed_large"]: TEST_3D_CASE.append( - [{"beta_schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] ) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE @@ -55,6 +55,14 @@ def test_step_shape(self, input_param, input_shape, expected_shape): self.assertEqual(output_step[0].shape, expected_shape) self.assertEqual(output_step[1].shape, expected_shape) + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + 
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + def test_step_learned(self): for variance_type in ["learned", "learned_range"]: scheduler = DDPMScheduler(variance_type=variance_type) diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index ee0cda29..4e0dbb97 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -19,12 +19,12 @@ from generative.networks.schedulers import PNDMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_2D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_3D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb index cb4bd4b4..9a09dc95 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb @@ -851,7 +851,7 @@ " num_head_channels=(0, 256, 512),\n", ")\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py index 9face129..681c0a1e 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py @@ -310,7 +310,7 @@ num_head_channels=(0, 256, 512), ) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) # - # ### Scaling factor diff --git a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb index bad152f6..c6e58254 100644 --- a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb +++ b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb @@ -881,7 +881,7 @@ ")\n", "unet = unet.to(device)\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { @@ -899,7 +899,7 @@ "metadata": {}, "outputs": [], "source": [ - "low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)\n", + "low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)\n", "\n", "max_noise_level = 350" ] diff --git 
a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py index 1234935d..82369fd7 100644 --- a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py +++ b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py @@ -323,13 +323,13 @@ ) unet = unet.to(device) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) # %% [markdown] # As mentioned, we will use the conditioned augmentation (introduced in [2] section 3 and used on Stable Diffusion Upscalers and Imagen Video [3] Section 2.5) as it has been shown critical for cascaded diffusion models, as well for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We will use a scheduler `low_res_scheduler` to add this noise, with the `t` step defining the signal-to-noise ratio and use the `t` value to condition the diffusion model (inputted using `class_labels` argument). # %% -low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195) +low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) max_noise_level = 350 diff --git a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb index e5d0f2fb..1174e567 100644 --- a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb +++ b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb @@ -357,7 +357,7 @@ "metadata": {}, "outputs": [], "source": [ - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0005, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0005, beta_end=0.0195)" ] }, { @@ -901,7 +901,7 @@ ], "source": [ "scheduler_ddim = DDIMScheduler(\n", - " num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0005, beta_end=0.0195, clip_sample=False\n", + " num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0005, beta_end=0.0195, clip_sample=False\n", ")\n", "\n", "scheduler_ddim.set_timesteps(num_inference_steps=250)\n", diff --git a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py index 7ea85756..527b96d6 100644 --- a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py +++ b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py @@ -164,7 +164,7 @@ # Together with our U-net, we need to define the Noise Scheduler for the diffusion model. This scheduler is responsible for defining the amount of noise that should be added in each timestep `t` of the diffusion model's Markov chain. Besides that, it has the operations to perform the reverse process, which will remove the noise of the images (a.k.a. denoising process). In this case, we are using a `DDPMScheduler`. Here we are using 1000 timesteps and a `scaled_linear` profile for the beta values (proposed in [Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models"](https://arxiv.org/abs/2112.10752)). This profile had better results than the `linear, proposed in the original DDPM's paper. In `beta_start` and `beta_end`, we define the limits for the beta values. 
These are important to determine how accentuated is the addition of noise in the image. # %% -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0005, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195) # %% plt.plot(scheduler.alphas_cumprod.cpu(), color=(2 / 255, 163 / 255, 163 / 255), linewidth=2) @@ -310,7 +310,7 @@ # %% scheduler_ddim = DDIMScheduler( - num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0005, beta_end=0.0195, clip_sample=False + num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195, clip_sample=False ) scheduler_ddim.set_timesteps(num_inference_steps=250) diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb index 48e96ffe..5e07974f 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb @@ -741,7 +741,7 @@ "unet.to(device)\n", "\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py index 0cf6a302..6ea8cfb0 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py @@ -308,7 +308,7 @@ def KL_loss(z_mu, z_sigma): unet.to(device) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0015, beta_end=0.0195) # - # ### Scaling factor diff --git a/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py index 706b82c5..fb3d80b0 100644 --- a/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py +++ b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py @@ -395,6 +395,7 @@ # %% [markdown] # ### Visualize anomaly map + # %% def visualize(img): _min = img.min() diff --git a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py index fadae0bc..2db60a88 100644 --- a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py +++ b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py @@ -335,7 +335,6 @@ progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110) progress_bar.set_description(f"Epoch {epoch}") for step, batch in progress_bar: - images = batch["image"].to(device) optimizer.zero_grad(set_to_none=True) @@ -358,7 +357,6 @@ val_loss = 0 with torch.no_grad(): for val_step, batch in enumerate(val_loader, start=1): - images = batch["image"].to(device) logits, quantizations_target, _ = inferer( diff --git a/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py b/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py index b7fd3d6e..eaa08f5b 100644 --- 
a/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py +++ b/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py @@ -368,7 +368,6 @@ def dice_coeff(im1, im2, empty_score=1.0): # + for i in range(len(ensemble)): - prediction = torch.where(ensemble[i] > 0.5, 1, 0).float() # a binary mask is obtained via thresholding score = dice_coeff( prediction[0, 0].cpu(), inputlabel.cpu()
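
Usage sketch for the refactored scheduler API. The following minimal example registers a custom beta schedule in `NoiseSchedules` and uses it with `DDPMScheduler`; it assumes only what this diff introduces, and the schedule name `quadratic_beta`, its values, and the tensor shapes are illustrative assumptions, not part of the change:

```python
# Minimal sketch, assuming only the API added in this diff. The schedule name
# "quadratic_beta" and the tensor shapes below are illustrative assumptions.
import torch

from generative.networks.schedulers import DDPMScheduler, NoiseSchedules


@NoiseSchedules.add_def("quadratic_beta", "Quadratic beta schedule (illustrative)")
def _quadratic_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Interpolate between beta_start**0.5 and beta_end**0.5, then square."""
    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2


# extra keyword arguments (here beta_end) are forwarded to the schedule function via schedule_args
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="quadratic_beta", beta_end=0.0195)

images = torch.rand(2, 1, 16, 16)
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()
noisy = scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
print(noisy.shape)  # torch.Size([2, 1, 16, 16])
```

Because `Scheduler.__init__` forwards `schedule_args` to the registered function, any schedule-specific parameter such as `beta_end` can be passed straight through the scheduler constructor, and `print(NoiseSchedules)` lists the registered schedules with their docstring descriptions.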