From f675b395943107a79525a5a7cd5bd0536eef8d2e Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Fri, 2 Dec 2022 10:36:19 -0600
Subject: [PATCH 01/42] Fixes return type in sample

---
 generative/inferers/inferer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index d7918e79..13a07e8d 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -66,7 +66,7 @@ def sample(
         intermediate_steps: Optional[int] = 100,
         conditioning: Optional[torch.Tensor] = None,
         verbose: Optional[bool] = True,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Args:
            input_noise: random noise, of the same shape as the desired sample.

From ec2f1c11eb0835469426430aa82238c0bcf585f9 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Fri, 2 Dec 2022 17:14:07 -0600
Subject: [PATCH 02/42] Adds method to compute posterior mean

---
 generative/schedulers/ddpm.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/generative/schedulers/ddpm.py b/generative/schedulers/ddpm.py
index f2ea54e2..510c6438 100644
--- a/generative/schedulers/ddpm.py
+++ b/generative/schedulers/ddpm.py
@@ -82,7 +82,7 @@ def __init__(
         self.clip_sample = clip_sample
         self.variance_type = variance_type

-        # setable values
+        # settable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

@@ -101,9 +101,33 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
         ].copy()
         self.timesteps = torch.from_numpy(timesteps).to(device)

+    def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the mean of the posterior at timestep t.
+
+        Args:
+            timestep: current timestep.
+            x_0: the noise-free input.
+            x_t: the input noised to timestep t.
+
+        Returns:
+            the posterior mean.
+        """
+        # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
+        # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+        x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
+        x_t_coefficient = alpha_prod_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
+
+        mean = x_0_coefficient * x_0 + x_t_coefficient * x_t
+
+        return mean
+
     def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
-        Compute the variance.
+        Compute the variance of the posterior at timestep t.

         Args:
             timestep: current timestep.
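Note: the likelihood computation introduced in the next patch accumulates, per timestep, the closed-form KL divergence between the true posterior q(x_{t-1}|x_t, x_0) and the model's predicted reverse transition. For reference, a standalone sketch of that term (illustrative only, not part of the patch; log-variances are used for numerical stability):

    import torch

    def gaussian_kl(mean1, logvar1, mean2, logvar2):
        # KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))), elementwise;
        # mean1/logvar1 play the role of the posterior and mean2/logvar2 the
        # model prediction, exactly as in the code added below.
        return 0.5 * (
            -1.0
            + logvar2
            - logvar1
            + torch.exp(logvar1 - logvar2)
            + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
        )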
From 8e68dd9ad448ed2bf53b9e93669abf42bc8bddcb Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 2 Dec 2022 17:14:44 -0600 Subject: [PATCH 03/42] Initial code for computing likelihood --- generative/inferers/inferer.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 13a07e8d..044f0475 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -12,6 +12,7 @@ from typing import Callable, List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn from monai.inferers import Inferer @@ -101,6 +102,75 @@ def sample( else: return image + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Optional[Callable[..., torch.Tensor]] = None, + save_intermediates: Optional[bool] = False, + conditioning: Optional[torch.Tensor] = None, + verbose: Optional[bool] = True, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + """ + Computes the likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the + conditioning: + verbose: if true, prints the progression bar of the sampling process. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros_like(inputs) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + # get the model's predicted mean and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + predicted_mean, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_mean = model_output + predicted_variance = None + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl + if save_intermediates: + intermediates.append(kl.cpu()) + total_kl = total_kl.view(total_kl.shape[0], -1).sum(axis=1) + log_likelihood_per_dim = -total_kl / np.prod(inputs.shape[1:]) + if save_intermediates: + return log_likelihood_per_dim, intermediates + else: + return log_likelihood_per_dim + class LatentDiffusionInferer(DiffusionInferer): """ From 
b803356eb32a4296a5f3e5134c8dfa1b2cc507c5 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 5 Dec 2022 14:33:17 -0600 Subject: [PATCH 04/42] Fixes bug in get_mean --- generative/schedulers/ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/schedulers/ddpm.py b/generative/schedulers/ddpm.py index 510c6438..e15f32c8 100644 --- a/generative/schedulers/ddpm.py +++ b/generative/schedulers/ddpm.py @@ -115,11 +115,12 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc """ # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) - x_t_coefficient = alpha_prod_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) mean = x_0_coefficient * x_0 + x_t_coefficient * x_t @@ -143,7 +144,6 @@ def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] - # hacks - were probably added for training stability if self.variance_type == "fixed_small": variance = torch.clamp(variance, min=1e-20) From 60ad56c27795e5e08622be9228c56418cdbf939c Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 5 Dec 2022 14:34:17 -0600 Subject: [PATCH 05/42] Calculates mean/var from epsilon --- generative/inferers/inferer.py | 52 +++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 044f0475..5829700e 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn from monai.inferers import Inferer @@ -107,6 +106,7 @@ def get_likelihood( inputs: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], scheduler: Optional[Callable[..., torch.Tensor]] = None, + predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, verbose: Optional[bool] = True, @@ -118,7 +118,9 @@ def get_likelihood( inputs: input images, NxCxHxW[xD] diffusion_model: model to compute likelihood from scheduler: diffusion scheduler. If none provided will use the class attribute scheduler - save_intermediates: save the + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + + save_intermediates: save the intermediate spatial KL maps conditioning: verbose: if true, prints the progression bar of the sampling process. 
""" @@ -136,20 +138,49 @@ def get_likelihood( progress_bar = iter(scheduler.timesteps) intermediates = [] noise = torch.randn_like(inputs).to(inputs.device) - total_kl = torch.zeros_like(inputs) + total_kl = torch.zeros((inputs.shape[0])).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) - # get the model's predicted mean and variance if it is predicted + # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: - predicted_mean, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: - predicted_mean = model_output predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = (model_output - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. 
Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + # get the posterior mean and variance posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + # at t=0 variance is 0 and the log-variance blows up, fix this + if t == 0: + posterior_variance = torch.Tensor([1]).to(posterior_mean.device) log_posterior_variance = torch.log(posterior_variance) log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance @@ -161,15 +192,14 @@ def get_likelihood( + torch.exp(log_posterior_variance - log_predicted_variance) + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) - total_kl += kl + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: intermediates.append(kl.cpu()) - total_kl = total_kl.view(total_kl.shape[0], -1).sum(axis=1) - log_likelihood_per_dim = -total_kl / np.prod(inputs.shape[1:]) + if save_intermediates: - return log_likelihood_per_dim, intermediates + return total_kl, intermediates else: - return log_likelihood_per_dim + return total_kl class LatentDiffusionInferer(DiffusionInferer): From d19185a51977b5aa58d056bd5122075bd94752cf Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 9 Dec 2022 12:57:12 -0600 Subject: [PATCH 06/42] Fixes bug in predicting input from noise --- generative/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 5829700e..f0e087f1 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -158,7 +158,7 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if predict_epsilon: - pred_original_sample = (model_output - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: pred_original_sample = model_output From b6b13ab8c18cb7ee518f59344326faa44b5aa1e8 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 11:26:02 -0600 Subject: [PATCH 07/42] Adds decoder log-likelihood --- generative/inferers/inferer.py | 73 ++++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index f0e087f1..ec0c7d7f 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -10,6 +10,7 @@ # limitations under the License. 
+import math
 from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from monai.inferers import Inferer
@@ -140,6 +141,7 @@ def get_likelihood(
         noise = torch.randn_like(inputs).to(inputs.device)
         total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
         for t in progress_bar:
+
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
             model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
@@ -184,14 +186,18 @@ def get_likelihood(
             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance

-            # compute kl between two normals
-            kl = 0.5 * (
-                -1.0
-                + log_predicted_variance
-                - log_posterior_variance
-                + torch.exp(log_posterior_variance - log_predicted_variance)
-                + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
-            )
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(inputs, predicted_mean, 0.5 * log_predicted_variance)
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
             total_kl += kl.view(kl.shape[0], -1).mean(axis=1)
             if save_intermediates:
                 intermediates.append(kl.cpu())
@@ -201,6 +207,57 @@ def get_likelihood(
         else:
             return total_kl

+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+
+        return 0.5 * (
+            1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
+        )
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: Optional[Tuple] = [0, 255],
+        scaled_input_range: Optional[Tuple] = [0, 1],
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            inputs: the target images. It is assumed that this was uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean Tensor.
+            log_scales: the Gaussian log stddev Tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ """ + assert inputs.shape == means.shape + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == inputs.shape + return log_probs + class LatentDiffusionInferer(DiffusionInferer): """ From efc4020acdbbf45e9eaf609a58ec530ce011bb9f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 11:41:15 -0600 Subject: [PATCH 08/42] Adds log-likelihood calculation for latent diffusion model --- generative/inferers/inferer.py | 49 +++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index ec0c7d7f..90e23498 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -110,6 +110,8 @@ def get_likelihood( predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, + original_input_range: Optional[Tuple] = [0, 255], + scaled_input_range: Optional[Tuple] = [0, 1], verbose: Optional[bool] = True, ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: """ @@ -122,7 +124,9 @@ def get_likelihood( predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. save_intermediates: save the intermediate spatial KL maps - conditioning: + conditioning: Conditioning for network input. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. verbose: if true, prints the progression bar of the sampling process. """ @@ -355,3 +359,46 @@ def sample( else: return image + + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Optional[Callable[..., torch.Tensor]] = None, + predict_epsilon: bool = True, + save_intermediates: Optional[bool] = False, + conditioning: Optional[torch.Tensor] = None, + original_input_range: Optional[Tuple] = [0, 255], + scaled_input_range: Optional[Tuple] = [0, 1], + verbose: Optional[bool] = True, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + """ + Computes the likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. 
+        verbose: if true, prints the progression bar of the sampling process.
+        """

+        with torch.no_grad():
+            latents = autoencoder_model.encode_stage_2_outputs(inputs) * self.scale_factor
+            outputs = super().get_likelihood(
+                inputs=latents,
+                diffusion_model=diffusion_model,
+                scheduler=scheduler,
+                predict_epsilon=predict_epsilon,
+                save_intermediates=save_intermediates,
+                conditioning=conditioning,
+                verbose=verbose,
+            )
+        return outputs
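Note: patches 05 and 06 recover "predicted x_0" from the predicted noise by inverting the forward process, x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * epsilon. A round-trip sanity check of that inversion (a sketch; the import path is assumed from the scheduler's location in these patches, and add_noise is the same method get_likelihood calls above):

    import torch
    from generative.networks.schedulers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    x_0 = torch.rand(2, 1, 8, 8) * 2 - 1  # images scaled to [-1, 1]
    noise = torch.randn_like(x_0)
    t = 500
    timesteps = torch.full((x_0.shape[0],), t).long()
    x_t = scheduler.add_noise(original_samples=x_0, noise=noise, timesteps=timesteps)

    # invert formula (15): x_0 = (x_t - sqrt(1 - alpha_prod_t) * eps) / sqrt(alpha_prod_t)
    alpha_prod_t = scheduler.alphas_cumprod[t]
    beta_prod_t = 1 - alpha_prod_t
    recovered = (x_t - beta_prod_t ** 0.5 * noise) / alpha_prod_t ** 0.5
    torch.testing.assert_close(recovered, x_0, atol=1e-4, rtol=1e-4)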
From 547d2479b7e9a6f38542c53c343f5d6774543e66 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Mon, 12 Dec 2022 12:06:44 -0600
Subject: [PATCH 17/42] Adds tests

---
 tests/test_diffusion_inferer.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py
index 0ceb71ef..95b6a26c 100644
--- a/tests/test_diffusion_inferer.py
+++ b/tests/test_diffusion_inferer.py
@@ -140,6 +140,37 @@ def test_sampler_conditioned(self, model_params, input_shape):
         )
         self.assertEqual(len(intermediates), 10)

+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        likelihood, intermediates = inferer.get_likelihood(
+            inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
+        )
+        self.assertEqual(intermediates[0].shape, input.shape)
+        self.assertEqual(likelihood.shape[0], input.shape[0])
+
+    def test_normal_cdf(self):
+        from scipy.stats import norm
+
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = DiffusionInferer(scheduler=scheduler)
+
+        x = torch.linspace(-10, 10, 20)
+        cdf_approx = inferer._approx_standard_normal_cdf(x)
+        cdf_true = norm.cdf(x)
+        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+

 if __name__ == "__main__":
     unittest.main()

From 043dda7e493bd958290d2397625662759213634c Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Mon, 12 Dec 2022 12:15:35 -0600
Subject: [PATCH 18/42] Adds latent tests

---
 generative/inferers/inferer.py         |  2 +-
 tests/test_latent_diffusion_inferer.py | 28 ++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 90e23498..6ab8c49c 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -391,7 +391,7 @@ def get_likelihood(
         """

         with torch.no_grad():
-            latents = autoencoder_model.encode_stage_2_outputs(inputs) * self.scale_factor
+            latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
             outputs = super().get_likelihood(
                 inputs=latents,
                 diffusion_model=diffusion_model,
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index 81cb8002..3f9c81ba 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -153,6 +153,34 @@ def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_para
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape, input_shape)

+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
+        if model_type == "AutoencoderKL":
+            autoencoder_model = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            autoencoder_model = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        autoencoder_model.to(device)
+        stage_2.to(device)
+        autoencoder_model.eval()
+        stage_2.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.get_likelihood(
+            inputs=input,
+            autoencoder_model=autoencoder_model,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            save_intermediates=True,
+        )
+        self.assertEqual(len(intermediates), 10)
+        self.assertEqual(intermediates[0].shape, latent_shape)
+

 if __name__ == "__main__":
     unittest.main()

From 14dc49651c183125ccc4750202b520976320c273 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Tue, 13 Dec 2022 11:58:04 -0600
Subject: [PATCH 19/42] Pass input scalings to decoder calc

---
 generative/inferers/inferer.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 6ab8c49c..90dbd552 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -184,15 +184,19 @@ def get_likelihood(
             # get the posterior mean and variance
             posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
             posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
-            # at t=0 variance is 0 and the log-variance blows up, fix this
-            if t == 0:
-                posterior_variance = torch.Tensor([1]).to(posterior_mean.device)
+
             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance

             if t == 0:
                 # compute -log p(x_0|x_1)
-                kl = -self._get_decoder_log_likelihood(inputs, predicted_mean, 0.5 * log_predicted_variance)
+                kl = -self._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
             else:
                 # compute kl between two normals
                 kl = 0.5 * (

From fedf06693c1132d77a32a6d1837f6e80bf248778 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Tue, 13 Dec 2022 12:02:47 -0600
Subject: [PATCH 20/42] Fix arg and docstring

---
 generative/inferers/inferer.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 90dbd552..3ccad90f 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -110,8 +110,8 @@ def get_likelihood(
         predict_epsilon: bool = True,
         save_intermediates: Optional[bool] = False,
         conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = [0, 255],
-        scaled_input_range: Optional[Tuple] = [0, 1],
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -120,9 +120,8 @@ def get_likelihood(
         Args:
             inputs: input images, NxCxHxW[xD]
             diffusion_model: model to compute likelihood from
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
        """
@@ -145,7 +144,6 @@ def get_likelihood(
         noise = torch.randn_like(inputs).to(inputs.device)
         total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
         for t in progress_bar:
-
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
             model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
@@ -373,8 +371,8 @@ def get_likelihood(
         predict_epsilon: bool = True,
         save_intermediates: Optional[bool] = False,
         conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = [0, 255],
-        scaled_input_range: Optional[Tuple] = [0, 1],
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -385,13 +383,12 @@ def get_likelihood(
            autoencoder_model: first stage model.
            diffusion_model: model to compute likelihood from
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
        """

        with torch.no_grad():
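Note: with the series applied, the likelihood path can be exercised end-to-end. A usage sketch mirroring the tests above (import paths are assumed from this repository's layout; diffusion_model stands for a trained DiffusionModelUNet and is assumed available, since untrained weights run but give meaningless likelihoods):

    import torch
    from generative.inferers import DiffusionInferer
    from generative.networks.schedulers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=1000)
    inferer = DiffusionInferer(scheduler=scheduler)

    images = torch.rand(4, 1, 64, 64) * 2 - 1  # inputs scaled to [-1, 1]
    with torch.no_grad():
        likelihood, kl_maps = inferer.get_likelihood(
            inputs=images,
            diffusion_model=diffusion_model,
            scheduler=scheduler,
            save_intermediates=True,  # also return per-timestep spatial KL maps
        )
    # likelihood: one value per batch element; kl_maps: per-voxel terms that
    # can serve, for example, as anomaly maps.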
+import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from monai.inferers import Inferer
@@ -140,6 +141,7 @@ def get_likelihood(
         noise = torch.randn_like(inputs).to(inputs.device)
         total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
         for t in progress_bar:
+
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
             model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
@@ -184,14 +186,18 @@ def get_likelihood(
             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
 
-            # compute kl between two normals
-            kl = 0.5 * (
-                -1.0
-                + log_predicted_variance
-                - log_posterior_variance
-                + torch.exp(log_posterior_variance - log_predicted_variance)
-                + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
-            )
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(inputs, predicted_mean, 0.5 * log_predicted_variance)
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
             total_kl += kl.view(kl.shape[0], -1).mean(axis=1)
             if save_intermediates:
                 intermediates.append(kl.cpu())
@@ -201,6 +207,57 @@ def get_likelihood(
         else:
             return total_kl
 
+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+
+        return 0.5 * (
+            1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
+        )
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: Optional[Tuple] = [0, 255],
+        scaled_input_range: Optional[Tuple] = [0, 1],
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            inputs: the target images. It is assumed that these were uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean Tensor.
+            log_scales: the Gaussian log stddev Tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ """ + assert inputs.shape == means.shape + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == inputs.shape + return log_probs + class LatentDiffusionInferer(DiffusionInferer): """ From e45fad913c9cca0e5c85a81a0d8cd1357ab93565 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 11:41:15 -0600 Subject: [PATCH 28/42] Adds log-likelihood calculation for latent diffusion model --- generative/inferers/inferer.py | 49 +++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 06b868fa..966ccbaf 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -110,6 +110,8 @@ def get_likelihood( predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, + original_input_range: Optional[Tuple] = [0, 255], + scaled_input_range: Optional[Tuple] = [0, 1], verbose: Optional[bool] = True, ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: """ @@ -122,7 +124,9 @@ def get_likelihood( predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. save_intermediates: save the intermediate spatial KL maps - conditioning: + conditioning: Conditioning for network input. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. verbose: if true, prints the progression bar of the sampling process. """ @@ -358,3 +362,46 @@ def sample( else: return image + + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Optional[Callable[..., torch.Tensor]] = None, + predict_epsilon: bool = True, + save_intermediates: Optional[bool] = False, + conditioning: Optional[torch.Tensor] = None, + original_input_range: Optional[Tuple] = [0, 255], + scaled_input_range: Optional[Tuple] = [0, 1], + verbose: Optional[bool] = True, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + """ + Computes the likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. 
+ verbose: if true, prints the progression bar of the sampling process. + """ + + with torch.no_grad(): + latents = autoencoder_model.encode_stage_2_outputs(inputs) * self.scale_factor + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + predict_epsilon=predict_epsilon, + save_intermediates=save_intermediates, + conditioning=conditioning, + verbose=verbose, + ) + return outputs From 432c1d0b5a5fe34879046b618097ad29164d4991 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 12:06:44 -0600 Subject: [PATCH 29/42] Adds tests --- tests/test_diffusion_inferer.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 8b3cd511..74f725c4 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -141,6 +141,37 @@ def test_sampler_conditioned(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler( + num_train_timesteps=10, + ) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler( + num_train_timesteps=10, + ) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + if __name__ == "__main__": unittest.main() From 92c7560c9488f9de4c9a4da5e6f7a47880a62c18 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 12:15:35 -0600 Subject: [PATCH 30/42] Adds latent tests --- generative/inferers/inferer.py | 2 +- tests/test_latent_diffusion_inferer.py | 28 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 966ccbaf..1fc7b51d 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -394,7 +394,7 @@ def get_likelihood( """ with torch.no_grad(): - latents = autoencoder_model.encode_stage_2_outputs(inputs) * self.scale_factor + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor outputs = super().get_likelihood( inputs=latents, diffusion_model=diffusion_model, diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 450b940e..dc9e26ab 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -156,6 +156,34 @@ def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_para self.assertEqual(len(intermediates), 10) self.assertEqual(intermediates[0].shape, input_shape) + @parameterized.expand(TEST_CASES) + def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape): + if model_type == "AutoencoderKL": + 
autoencoder_model = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            autoencoder_model = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        autoencoder_model.to(device)
+        stage_2.to(device)
+        autoencoder_model.eval()
+        stage_2.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.get_likelihood(
+            inputs=input,
+            autoencoder_model=autoencoder_model,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            save_intermediates=True,
+        )
+        self.assertEqual(len(intermediates), 10)
+        self.assertEqual(intermediates[0].shape, latent_shape)
+
 
 if __name__ == "__main__":
     unittest.main()

From 62b0f701d1d65d29cc9299f97b533d0d9f31af6e Mon Sep 17 00:00:00 2001
From: Mark Graham 
Date: Fri, 2 Dec 2022 17:14:07 -0600
Subject: [PATCH 31/42] Adds method to compute posterior mean

---
 generative/networks/schedulers/ddpm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py
index 9f5ca107..9dbce696 100644
--- a/generative/networks/schedulers/ddpm.py
+++ b/generative/networks/schedulers/ddpm.py
@@ -123,6 +123,7 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc
         """
         # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
         # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
+        alpha_t = self.alphas[timestep]
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
@@ -130,6 +131,7 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc
         x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
         x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
 
+
         mean = x_0_coefficient * x_0 + x_t_coefficient * x_t
 
         return mean

From 3ed147a72fe90eaa0f03a860e55cfce383500e5e Mon Sep 17 00:00:00 2001
From: Mark Graham 
Date: Fri, 2 Dec 2022 17:14:44 -0600
Subject: [PATCH 32/42] Initial code for computing likelihood

---
 generative/inferers/inferer.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 1fc7b51d..9dc69f85 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -13,6 +13,7 @@
 import math
 from typing import Callable, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from monai.inferers import Inferer
@@ -107,7 +108,6 @@ def get_likelihood(
         inputs: torch.Tensor,
         diffusion_model: Callable[..., torch.Tensor],
         scheduler: Optional[Callable[..., torch.Tensor]] = None,
-        predict_epsilon: bool = True,
         save_intermediates: Optional[bool] = False,
         conditioning: Optional[torch.Tensor] = None,
         original_input_range: Optional[Tuple] = [0, 255],
         scaled_input_range: Optional[Tuple] = [0, 1],
@@ -121,8 +121,6 @@ def get_likelihood(
         inputs: input images, NxCxHxW[xD]
         diffusion_model: model to compute likelihood from
         scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
-        predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
-
         save_intermediates: save the intermediate spatial KL maps
         conditioning: Conditioning for network input.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied. @@ -262,7 +260,6 @@ def _get_decoder_log_likelihood( assert log_probs.shape == inputs.shape return log_probs - class LatentDiffusionInferer(DiffusionInferer): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can From 1e847e07bb62ab04402c9f2b70d770b0709ecc9d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 5 Dec 2022 14:33:17 -0600 Subject: [PATCH 33/42] Fixes bug in get_mean --- generative/networks/schedulers/ddpm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 9dbce696..9f5ca107 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -123,7 +123,6 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc """ # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) - alpha_t = self.alphas[timestep] alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one @@ -131,7 +130,6 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) - mean = x_0_coefficient * x_0 + x_t_coefficient * x_t return mean From 6d7a0ad94a67005bebc5a338aaba5258268b6130 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 5 Dec 2022 14:34:17 -0600 Subject: [PATCH 34/42] Calculates mean/var from epsilon --- generative/inferers/inferer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 9dc69f85..f1fe647d 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -13,7 +13,6 @@ import math from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn from monai.inferers import Inferer @@ -108,6 +107,7 @@ def get_likelihood( inputs: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], scheduler: Optional[Callable[..., torch.Tensor]] = None, + predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, original_input_range: Optional[Tuple] = [0, 255], @@ -162,7 +162,7 @@ def get_likelihood( # 2. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if predict_epsilon: - pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_original_sample = (model_output - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: pred_original_sample = model_output @@ -200,6 +200,7 @@ def get_likelihood( + torch.exp(log_posterior_variance - log_predicted_variance) + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: intermediates.append(kl.cpu()) From bdbc303caebe955bcc73ce7b1516682433b4de78 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 9 Dec 2022 12:57:12 -0600 Subject: [PATCH 35/42] Fixes bug in predicting input from noise --- generative/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index f1fe647d..ed09ad44 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -162,7 +162,7 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if predict_epsilon: - pred_original_sample = (model_output - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: pred_original_sample = model_output From 04f5c68f412957e1c34c99e8d34037177f56ef9e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 12 Dec 2022 11:26:02 -0600 Subject: [PATCH 36/42] Adds decoder log-likelihood --- generative/inferers/inferer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index ed09ad44..e6a79fb9 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -200,7 +200,6 @@ def get_likelihood( + torch.exp(log_posterior_variance - log_predicted_variance) + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) - total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: intermediates.append(kl.cpu()) From e1b311088aa460632f3b1c21f466026b29734e18 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 13 Dec 2022 11:58:04 -0600 Subject: [PATCH 37/42] Pass input scalings to decoder calc --- generative/inferers/inferer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index e6a79fb9..4cea6db6 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -182,15 +182,19 @@ def get_likelihood( # get the posterior mean and variance posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) - # at t=0 variance is 0 and the log-variance blows up, fix this - if t == 0: - posterior_variance = torch.Tensor([1]).to(posterior_mean.device) + log_posterior_variance = torch.log(posterior_variance) log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance if t == 0: # compute -log p(x_0|x_1) - kl = -self._get_decoder_log_likelihood(inputs, predicted_mean, 0.5 * log_predicted_variance) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, 
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
             else:
                 # compute kl between two normals
                 kl = 0.5 * (

From 1b73c205803ea2f87682f1e1564450c6e0ba4a5b Mon Sep 17 00:00:00 2001
From: Mark Graham 
Date: Tue, 13 Dec 2022 12:02:47 -0600
Subject: [PATCH 38/42] Fix arg and docstring

---
 generative/inferers/inferer.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 4cea6db6..db495ca2 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -110,8 +110,8 @@ def get_likelihood(
         predict_epsilon: bool = True,
         save_intermediates: Optional[bool] = False,
         conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = [0, 255],
-        scaled_input_range: Optional[Tuple] = [0, 1],
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -120,7 +120,8 @@ def get_likelihood(
         Args:
             inputs: input images, NxCxHxW[xD]
             diffusion_model: model to compute likelihood from
-            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
             save_intermediates: save the intermediate spatial KL maps
             conditioning: Conditioning for network input.
             original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
@@ -143,7 +144,6 @@ def get_likelihood(
         noise = torch.randn_like(inputs).to(inputs.device)
         total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
         for t in progress_bar:
-
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
             model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
@@ -373,8 +373,8 @@ def get_likelihood(
         predict_epsilon: bool = True,
         save_intermediates: Optional[bool] = False,
         conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = [0, 255],
-        scaled_input_range: Optional[Tuple] = [0, 1],
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -385,13 +385,12 @@ def get_likelihood(
             autoencoder_model: first stage model.
             diffusion_model: model to compute likelihood from
             scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
-            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
-
+            predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon.
             save_intermediates: save the intermediate spatial KL maps
             conditioning: Conditioning for network input.
             original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
             scaled_input_range: the [min,max] intensity range of the input data after scaling.
-            verbose: if true, prints the progression bar of the sampling process.
+            verbose: if true, prints the progression bar of the sampling process.
""" with torch.no_grad(): From 663c1bd8910819566aa7b5e2874ea7ba1dd8c926 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 13 Dec 2022 12:36:23 -0600 Subject: [PATCH 39/42] Include v-prediction and use scheduler prediction_type attribute --- generative/inferers/inferer.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index db495ca2..e4494f27 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -107,7 +107,6 @@ def get_likelihood( inputs: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], scheduler: Optional[Callable[..., torch.Tensor]] = None, - predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, original_input_range: Optional[Tuple] = (0, 255), @@ -121,7 +120,6 @@ def get_likelihood( inputs: input images, NxCxHxW[xD] diffusion_model: model to compute likelihood from scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. - predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. save_intermediates: save the intermediate spatial KL maps conditioning: Conditioning for network input. original_input_range: the [min,max] intensity range of the input data before any scaling was applied. @@ -161,11 +159,12 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if scheduler.prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif scheduler.prediction_type == "sample": pred_original_sample = model_output - + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" if scheduler.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) @@ -264,6 +263,7 @@ def _get_decoder_log_likelihood( assert log_probs.shape == inputs.shape return log_probs + class LatentDiffusionInferer(DiffusionInferer): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can @@ -370,7 +370,6 @@ def get_likelihood( autoencoder_model: Callable[..., torch.Tensor], diffusion_model: Callable[..., torch.Tensor], scheduler: Optional[Callable[..., torch.Tensor]] = None, - predict_epsilon: bool = True, save_intermediates: Optional[bool] = False, conditioning: Optional[torch.Tensor] = None, original_input_range: Optional[Tuple] = (0, 255), @@ -385,7 +384,6 @@ def get_likelihood( autoencoder_model: first stage model. diffusion_model: model to compute likelihood from scheduler: diffusion scheduler. If none provided will use the class attribute scheduler - predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. save_intermediates: save the intermediate spatial KL maps conditioning: Conditioning for network input. original_input_range: the [min,max] intensity range of the input data before any scaling was applied. 
@@ -399,7 +397,6 @@ def get_likelihood( inputs=latents, diffusion_model=diffusion_model, scheduler=scheduler, - predict_epsilon=predict_epsilon, save_intermediates=save_intermediates, conditioning=conditioning, verbose=verbose, From a84743fb8b00c841e30644086f099714d3216fc6 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 14 Dec 2022 12:09:50 -0600 Subject: [PATCH 40/42] Adds decorators for no_grad --- generative/inferers/inferer.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 5584a851..ee608305 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -102,6 +102,7 @@ def sample( else: return image + @torch.no_grad() def get_likelihood( self, inputs: torch.Tensor, @@ -364,6 +365,7 @@ def sample( else: return image + @torch.no_grad() def get_likelihood( self, inputs: torch.Tensor, @@ -391,14 +393,13 @@ def get_likelihood( verbose: if true, prints the progression bar of the sampling process. """ - with torch.no_grad(): - latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor - outputs = super().get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - verbose=verbose, - ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + verbose=verbose, + ) return outputs From b005464e22aa4285eb49eb0b82d644caed3ba15a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 14 Dec 2022 12:31:22 -0600 Subject: [PATCH 41/42] Adds option to resample latent likelihoods spatially --- generative/inferers/inferer.py | 17 +++++++++++++++ tests/test_latent_diffusion_inferer.py | 29 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index ee608305..2acc2938 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -377,6 +377,8 @@ def get_likelihood( original_input_range: Optional[Tuple] = (0, 255), scaled_input_range: Optional[Tuple] = (0, 1), verbose: Optional[bool] = True, + resample_latent_likelihoods: Optional[bool] = False, + resample_interpolation_mode: Optional[str] = "bilinear", ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: """ Computes the likelihoods for an input. @@ -391,6 +393,9 @@ def get_likelihood( original_input_range: the [min,max] intensity range of the input data before any scaling was applied. scaled_input_range: the [min,max] intensity range of the input data after scaling. verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. 
+            resample_interpolation_mode: if using resample_latent_likelihoods, select interpolation 'nearest' or 'bilinear'
         """
 
         latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
@@ -402,4 +407,16 @@ def get_likelihood(
             conditioning=conditioning,
             verbose=verbose,
         )
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            from torchvision.transforms import Resize
+
+            interpolation_modes = {"nearest": 0, "bilinear": 2}
+            if resample_interpolation_mode not in interpolation_modes.keys():
+                raise ValueError(
+                    f"resample_interpolation_mode should be either nearest or bilinear, not {resample_interpolation_mode}"
+                )
+            resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode])
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
         return outputs
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index dc9e26ab..57d251c9 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -184,6 +184,35 @@ def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, i
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape, latent_shape)
 
+    @parameterized.expand(TEST_CASES)
+    def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
+        if model_type == "AutoencoderKL":
+            autoencoder_model = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            autoencoder_model = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        autoencoder_model.to(device)
+        stage_2.to(device)
+        autoencoder_model.eval()
+        stage_2.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.get_likelihood(
+            inputs=input,
+            autoencoder_model=autoencoder_model,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            save_intermediates=True,
+            resample_latent_likelihoods=True,
+        )
+        self.assertEqual(len(intermediates), 10)
+        self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
+
 
 if __name__ == "__main__":
     unittest.main()

From d968a0e17c950aafb7d2955c598987d2e03e3400 Mon Sep 17 00:00:00 2001
From: Mark Graham 
Date: Wed, 14 Dec 2022 12:36:46 -0600
Subject: [PATCH 42/42] Updates docstring

---
 generative/inferers/inferer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 2acc2938..7f667ec3 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -381,7 +381,7 @@ def get_likelihood(
         resample_interpolation_mode: Optional[str] = "bilinear",
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
-        Computes the likelihoods for an input.
+        Computes the likelihoods of the latent representations of the input.
 
         Args:
             inputs: input images, NxCxHxW[xD]
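For reference, the quantities this series implements are the standard DDPM expressions from the paper cited throughout the diffs (Ho et al. 2020, https://arxiv.org/pdf/2006.11239.pdf). Restated in LaTeX: the forward-process posterior of Eq. (6)-(7), whose x_t coefficient uses sqrt(alpha_t) rather than sqrt(alpha-bar_t) -- precisely the bug corrected in PATCH 24 -- and the closed-form Gaussian KL accumulated per timestep in get_likelihood:

    % posterior q(x_{t-1} | x_t, x_0), Eq. (6)-(7) of Ho et al. (2020)
    q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t I\big)
    \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, x_0
        + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t,
    \qquad
    \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t

    % KL between the posterior q = N(mu_q, sigma_q^2) and the model's p = N(mu_p, sigma_p^2),
    % matching the kl = 0.5 * (...) expression in get_likelihood
    \mathrm{KL}(q \,\|\, p) = \frac{1}{2}\Big( \log\frac{\sigma_p^2}{\sigma_q^2} - 1
        + \frac{\sigma_q^2}{\sigma_p^2} + \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} \Big)

At t = 0 the KL term is replaced by the negative log-likelihood of a discretized Gaussian decoder (PATCH 27): the probability mass that N(mu_theta, sigma^2) assigns to the intensity bin around each pixel of x_0, following Eq. (13) of the same paper and the improved-diffusion code it credits.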
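To see the resulting API end to end, the sketch below strings the pieces together the same way the tests in this series do. It is a minimal illustration only: the module paths follow the repository layout visible in the diffs, but the UNet constructor arguments and tensor shapes are assumed stand-ins, not values taken from the test configs.

    # minimal sketch of the final get_likelihood API; the network
    # hyper-parameters and shapes below are illustrative assumptions
    import torch

    from generative.inferers import DiffusionInferer
    from generative.networks.nets import DiffusionModelUNet
    from generative.networks.schedulers import DDPMScheduler

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 8),
        attention_levels=(False, False),
        num_res_blocks=1,
        norm_num_groups=8,
    ).to(device)
    model.eval()

    scheduler = DDPMScheduler(num_train_timesteps=10)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = DiffusionInferer(scheduler=scheduler)

    images = torch.randn(2, 1, 16, 16).to(device)  # stand-in for real, scaled data

    # returns the per-image KL accumulated over all timesteps, plus one spatial
    # KL map per timestep when save_intermediates=True
    likelihood, kl_maps = inferer.get_likelihood(
        inputs=images, diffusion_model=model, scheduler=scheduler, save_intermediates=True
    )
    print(likelihood.shape)  # torch.Size([2])
    print(kl_maps[0].shape)  # torch.Size([2, 1, 16, 16])

The latent counterpart follows the same pattern: LatentDiffusionInferer.get_likelihood additionally takes autoencoder_model, and, after PATCH 41, passing resample_latent_likelihoods=True with resample_interpolation_mode set to 'nearest' or 'bilinear' resizes the per-timestep KL maps from latent resolution back to the spatial size of the input images.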