diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 9785c585..7f667ec3 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
+import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -66,7 +67,7 @@ def sample(
         intermediate_steps: Optional[int] = 100,
         conditioning: Optional[torch.Tensor] = None,
         verbose: Optional[bool] = True,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
         Args:
             input_noise: random noise, of the same shape as the desired sample.
@@ -101,6 +102,168 @@ def sample(
         else:
             return image
 
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: Callable[..., torch.Tensor],
+        scheduler: Optional[Callable[..., torch.Tensor]] = None,
+        save_intermediates: Optional[bool] = False,
+        conditioning: Optional[torch.Tensor] = None,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+        verbose: Optional[bool] = True,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Computes the likelihoods for an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none is provided, will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: conditioning for network input.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progress bar of the sampling process.
+        """
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                f"Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0], device=inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+            # get the model's predicted mean, and variance if it is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute the predicted original sample from the predicted noise, also called
+            # "predicted x_0" in formula (15) of https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+
+            # 3. clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. compute the coefficients for pred_original_sample x_0 and the current sample x_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. compute the predicted previous sample µ_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute the KL divergence between the two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+        return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            inputs: the target images. It is assumed that these were uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean tensor.
+            log_scales: the Gaussian log stddev tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+        """
+        assert inputs.shape == means.shape
+        bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
+            original_input_range[1] - original_input_range[0]
+        )
+        centered_x = inputs - means
+        inv_stdv = torch.exp(-log_scales)
+        plus_in = inv_stdv * (centered_x + bin_width / 2)
+        cdf_plus = self._approx_standard_normal_cdf(plus_in)
+        min_in = inv_stdv * (centered_x - bin_width / 2)
+        cdf_min = self._approx_standard_normal_cdf(min_in)
+        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+        log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+        cdf_delta = cdf_plus - cdf_min
+        log_probs = torch.where(
+            inputs < -0.999,
+            log_cdf_plus,
+            torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+        )
+        assert log_probs.shape == inputs.shape
+        return log_probs
+
 
 class LatentDiffusionInferer(DiffusionInferer):
     """
@@ -201,3 +364,59 @@ def sample(
         else:
             return image
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: Callable[..., torch.Tensor],
+        diffusion_model: Callable[..., torch.Tensor],
+        scheduler: Optional[Callable[..., torch.Tensor]] = None,
+        save_intermediates: Optional[bool] = False,
+        conditioning: Optional[torch.Tensor] = None,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+        verbose: Optional[bool] = True,
+        resample_latent_likelihoods: Optional[bool] = False,
+        resample_interpolation_mode: Optional[str] = "bilinear",
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Computes the likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none is provided, will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: conditioning for network input.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progress bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same
+                spatial dimensions as the input images.
+            resample_interpolation_mode: when resample_latent_likelihoods is True, the interpolation to use,
+                either 'nearest' or 'bilinear'.
+        """
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            verbose=verbose,
+        )
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            from torchvision.transforms import InterpolationMode, Resize
+
+            interpolation_modes = {"nearest": InterpolationMode.NEAREST, "bilinear": InterpolationMode.BILINEAR}
+            if resample_interpolation_mode not in interpolation_modes:
+                raise ValueError(
+                    f"resample_interpolation_mode should be either 'nearest' or 'bilinear',"
+                    f" not {resample_interpolation_mode}"
+                )
+            resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode])
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
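For context when reviewing, a minimal usage sketch of the new DiffusionInferer.get_likelihood API. This is not part of the patch; the import paths follow this repo's layout, and the UNet hyperparameters and the 64x64 toy batch are illustrative assumptions only.

    import torch
    from generative.inferers import DiffusionInferer
    from generative.networks.nets import DiffusionModelUNet
    from generative.networks.schedulers import DDPMScheduler

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # hypothetical toy configuration; any valid DiffusionModelUNet works here
    model = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 8),
        attention_levels=(False, False),
        num_res_blocks=1,
        norm_num_groups=8,
    ).to(device).eval()

    scheduler = DDPMScheduler(num_train_timesteps=10)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = DiffusionInferer(scheduler=scheduler)

    images = torch.randn(2, 1, 64, 64, device=device)  # stand-in NxCxHxW batch
    likelihood, kl_maps = inferer.get_likelihood(
        inputs=images, diffusion_model=model, scheduler=scheduler, save_intermediates=True
    )
    # likelihood: shape (2,), one accumulated KL value per batch element
    # kl_maps: one spatial KL map per timestep, each of shape (2, 1, 64, 64)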
diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py
index f0a64482..9f5ca107 100644
--- a/generative/networks/schedulers/ddpm.py
+++ b/generative/networks/schedulers/ddpm.py
@@ -90,7 +90,7 @@ def __init__(
         self.clip_sample = clip_sample
         self.variance_type = variance_type
 
-        # setable values
+        # settable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
@@ -109,9 +109,34 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
         ].copy()
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
+    def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the mean of the posterior at timestep t.
+
+        Args:
+            timestep: current timestep.
+            x_0: the noise-free input.
+            x_t: the input noised to timestep t.
+
+        Returns:
+            the posterior mean.
+        """
+        # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0)
+        # (see formulas (5)-(7) from https://arxiv.org/pdf/2006.11239.pdf)
+        alpha_t = self.alphas[timestep]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+        x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
+        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
+
+        mean = x_0_coefficient * x_0 + x_t_coefficient * x_t
+
+        return mean
+
     def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
-        Compute the variance.
+        Compute the variance of the posterior at timestep t.
 
         Args:
             timestep: current timestep.
@@ -127,7 +152,6 @@ def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
-        # hacks - were probably added for training stability
         if self.variance_type == "fixed_small":
             variance = torch.clamp(variance, min=1e-20)
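The per-timestep term accumulated in get_likelihood is the standard closed form for the KL divergence between two normals, KL(q || p), with q the forward posterior and p the model's predicted reverse step. A small self-contained sanity check against torch.distributions (all names are local to this snippet):

    import torch
    from torch.distributions import Normal, kl_divergence

    mu_q, var_q = torch.tensor(0.3), torch.tensor(0.5)   # posterior mean/variance
    mu_p, var_p = torch.tensor(-0.1), torch.tensor(0.8)  # predicted mean/variance
    log_var_q, log_var_p = var_q.log(), var_p.log()

    # same expression as in get_likelihood, with posterior -> q and predicted -> p
    kl_closed_form = 0.5 * (
        -1.0
        + log_var_p
        - log_var_q
        + torch.exp(log_var_q - log_var_p)
        + (mu_q - mu_p) ** 2 * torch.exp(-log_var_p)
    )
    kl_reference = kl_divergence(Normal(mu_q, var_q.sqrt()), Normal(mu_p, var_p.sqrt()))
    assert torch.isclose(kl_closed_form, kl_reference)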
diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py
index 8b3cd511..74f725c4 100644
--- a/tests/test_diffusion_inferer.py
+++ b/tests/test_diffusion_inferer.py
@@ -141,6 +141,37 @@ def test_sampler_conditioned(self, model_params, input_shape):
         )
         self.assertEqual(len(intermediates), 10)
 
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        likelihood, intermediates = inferer.get_likelihood(
+            inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
+        )
+        self.assertEqual(intermediates[0].shape, input.shape)
+        self.assertEqual(likelihood.shape[0], input.shape[0])
+
+    def test_normal_cdf(self):
+        from scipy.stats import norm
+
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = DiffusionInferer(scheduler=scheduler)
+
+        x = torch.linspace(-10, 10, 20)
+        cdf_approx = inferer._approx_standard_normal_cdf(x)
+        cdf_true = norm.cdf(x)
+        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index 450b940e..57d251c9 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -156,6 +156,63 @@ def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_para
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape, input_shape)
 
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
+        if model_type == "AutoencoderKL":
+            autoencoder_model = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            autoencoder_model = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        autoencoder_model.to(device)
+        stage_2.to(device)
+        autoencoder_model.eval()
+        stage_2.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+        likelihood, intermediates = inferer.get_likelihood(
+            inputs=input,
+            autoencoder_model=autoencoder_model,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            save_intermediates=True,
+        )
+        self.assertEqual(len(intermediates), 10)
+        self.assertEqual(intermediates[0].shape, latent_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
+        if model_type == "AutoencoderKL":
+            autoencoder_model = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            autoencoder_model = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        autoencoder_model.to(device)
+        stage_2.to(device)
+        autoencoder_model.eval()
+        stage_2.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(
+            num_train_timesteps=10,
+        )
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+        likelihood, intermediates = inferer.get_likelihood(
+            inputs=input,
+            autoencoder_model=autoencoder_model,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            save_intermediates=True,
+            resample_latent_likelihoods=True,
+        )
+        self.assertEqual(len(intermediates), 10)
+        self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
+
 
 if __name__ == "__main__":
     unittest.main()
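Finally, the latent counterpart, sketched under the same illustrative assumptions; the AutoencoderKL configuration below is a placeholder, and any first-stage model exposing encode_stage_2_inputs should work:

    import torch
    from generative.inferers import LatentDiffusionInferer
    from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
    from generative.networks.schedulers import DDPMScheduler

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # hypothetical toy first-stage model; hyperparameters are placeholders
    autoencoder = AutoencoderKL(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 8),
        latent_channels=3,
        num_res_blocks=1,
        norm_num_groups=8,
        attention_levels=(False, False),
    ).to(device).eval()
    stage_2 = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=3,   # must match the first stage's latent_channels
        out_channels=3,
        num_channels=(8, 8),
        attention_levels=(False, False),
        num_res_blocks=1,
        norm_num_groups=8,
    ).to(device).eval()

    scheduler = DDPMScheduler(num_train_timesteps=10)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

    images = torch.randn(2, 1, 32, 32, device=device)
    likelihood, kl_maps = inferer.get_likelihood(
        inputs=images,
        autoencoder_model=autoencoder,
        diffusion_model=stage_2,
        scheduler=scheduler,
        save_intermediates=True,
        resample_latent_likelihoods=True,  # KL maps are resized to the input's HxW
    )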