Merged
Commits
44 commits
f675b39
Fixes return type in sample
marksgraham Dec 2, 2022
ec2f1c1
Adds method to compute posterior mean
marksgraham Dec 2, 2022
8e68dd9
Initial code for computing likelihood
marksgraham Dec 2, 2022
b803356
Fixes bug in get_mean
marksgraham Dec 5, 2022
60ad56c
Calculates mean/var from epsilon
marksgraham Dec 5, 2022
d19185a
Fixes bug in predicting input from noise
marksgraham Dec 9, 2022
b6b13ab
Adds decoder log-likelihood
marksgraham Dec 12, 2022
efc4020
Adds log-likelihood calculation for latent diffusion model
marksgraham Dec 12, 2022
7422fbf
Fixes return type in sample
marksgraham Dec 2, 2022
17b7974
Adds method to compute posterior mean
marksgraham Dec 2, 2022
440b1db
Initial code for computing likelihood
marksgraham Dec 2, 2022
b0b07ec
Fixes bug in get_mean
marksgraham Dec 5, 2022
86051de
Calculates mean/var from epsilon
marksgraham Dec 5, 2022
0f499ec
Fixes bug in predicting input from noise
marksgraham Dec 9, 2022
1369812
Adds decoder log-likelihood
marksgraham Dec 12, 2022
177d692
Adds log-likelihood calculation for latent diffusion model
marksgraham Dec 12, 2022
547d247
Adds tests
marksgraham Dec 12, 2022
043dda7
Adds latent tests
marksgraham Dec 12, 2022
58c7814
Fix rebase conflict
marksgraham Dec 12, 2022
14dc496
Pass input scalings to decoder calc
marksgraham Dec 13, 2022
fedf066
Fix arg and docstring
marksgraham Dec 13, 2022
fa8031d
Fixes return type in sample
marksgraham Dec 2, 2022
1164b53
Adds method to compute posterior mean
marksgraham Dec 2, 2022
91a28d2
Initial code for computing likelihood
marksgraham Dec 2, 2022
9f217f4
Fixes bug in get_mean
marksgraham Dec 5, 2022
64c62d7
Calculates mean/var from epsilon
marksgraham Dec 5, 2022
4ac2758
Fixes bug in predicting input from noise
marksgraham Dec 9, 2022
9fbf962
Adds decoder log-likelihood
marksgraham Dec 12, 2022
e45fad9
Adds log-likelihood calculation for latent diffusion model
marksgraham Dec 12, 2022
432c1d0
Adds tests
marksgraham Dec 12, 2022
92c7560
Adds latent tests
marksgraham Dec 12, 2022
62b0f70
Adds method to compute posterior mean
marksgraham Dec 2, 2022
3ed147a
Initial code for computing likelihood
marksgraham Dec 2, 2022
1e847e0
Fixes bug in get_mean
marksgraham Dec 5, 2022
6d7a0ad
Calculates mean/var from epsilon
marksgraham Dec 5, 2022
bdbc303
Fixes bug in predicting input from noise
marksgraham Dec 9, 2022
04f5c68
Adds decoder log-likelihood
marksgraham Dec 12, 2022
e1b3110
Pass input scalings to decoder calc
marksgraham Dec 13, 2022
1b73c20
Fix arg and docstring
marksgraham Dec 13, 2022
663c1bd
Include v-prediction and use scheduler prediction_type attribute
marksgraham Dec 13, 2022
d1835fa
Fixes merge conflicts
marksgraham Dec 14, 2022
a84743f
Adds decorators for no_grad
marksgraham Dec 14, 2022
b005464
Adds option to resample latent likelihoods spatially
marksgraham Dec 14, 2022
d968a0e
Updates docstring
marksgraham Dec 14, 2022
221 changes: 220 additions & 1 deletion generative/inferers/inferer.py
@@ -10,6 +10,7 @@
# limitations under the License.


import math
from typing import Callable, List, Optional, Tuple, Union

import torch
@@ -66,7 +67,7 @@ def sample(
intermediate_steps: Optional[int] = 100,
conditioning: Optional[torch.Tensor] = None,
verbose: Optional[bool] = True,
) -> torch.Tensor:
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Args:
input_noise: random noise, of the same shape as the desired sample.
@@ -101,6 +102,168 @@ def sample(
else:
return image

@torch.no_grad()
def get_likelihood(
self,
inputs: torch.Tensor,
diffusion_model: Callable[..., torch.Tensor],
scheduler: Optional[Callable[..., torch.Tensor]] = None,
save_intermediates: Optional[bool] = False,
conditioning: Optional[torch.Tensor] = None,
original_input_range: Optional[Tuple] = (0, 255),
scaled_input_range: Optional[Tuple] = (0, 1),
verbose: Optional[bool] = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Computes the likelihoods for an input.

Args:
inputs: input images, NxCxHxW[xD]
diffusion_model: model to compute likelihood from
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
save_intermediates: save the intermediate spatial KL maps
conditioning: Conditioning for network input.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
scaled_input_range: the [min,max] intensity range of the input data after scaling.
verbose: if true, prints the progress bar of the sampling process.
"""

if not scheduler:
scheduler = self.scheduler
if scheduler._get_name() != "DDPMScheduler":
raise NotImplementedError(
f"Likelihood computation is only compatible with DDPMScheduler,"
f" you are using {scheduler._get_name()}"
)
if verbose and has_tqdm:
progress_bar = tqdm(scheduler.timesteps)
else:
progress_bar = iter(scheduler.timesteps)
intermediates = []
noise = torch.randn_like(inputs).to(inputs.device)
total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
# get the model's predicted mean, and variance if it is predicted
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
else:
predicted_variance = None

# 1. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[t]
alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if scheduler.prediction_type == "epsilon":
pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif scheduler.prediction_type == "sample":
pred_original_sample = model_output
elif scheduler.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
# 3. Clip "predicted x_0"
if scheduler.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

# get the posterior mean and variance
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)

log_posterior_variance = torch.log(posterior_variance)
log_predicted_variance = torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance

if t == 0:
# compute -log p(x_0|x_1)
kl = -self._get_decoder_log_likelihood(
inputs=inputs,
means=predicted_mean,
log_scales=0.5 * log_predicted_variance,
original_input_range=original_input_range,
scaled_input_range=scaled_input_range,
)
else:
# compute kl between two normals
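# closed form: KL(q || p) = 0.5 * (log(var_p / var_q) + var_q / var_p + (mu_q - mu_p)^2 / var_p - 1),
# where q is the true posterior and p is the predicted reverse transition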
kl = 0.5 * (
-1.0
+ log_predicted_variance
- log_posterior_variance
+ torch.exp(log_posterior_variance - log_predicted_variance)
+ ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
)
total_kl += kl.view(kl.shape[0], -1).mean(axis=1)
if save_intermediates:
intermediates.append(kl.cpu())

if save_intermediates:
return total_kl, intermediates
else:
return total_kl

def _approx_standard_normal_cdf(self, x):
"""
A fast approximation of the cumulative distribution function of the
standard normal. Code adapted from https://github.com/openai/improved-diffusion.
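Uses the tanh-based approximation Phi(x) ~= 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).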
"""

return 0.5 * (
1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
)

def _get_decoder_log_likelihood(
self,
inputs: torch.Tensor,
means: torch.Tensor,
log_scales: torch.Tensor,
original_input_range: Optional[Tuple] = (0, 255),
scaled_input_range: Optional[Tuple] = (0, 1),
) -> torch.Tensor:
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image. Code adapted from https://github.com/openai/improved-diffusion.

Args:
inputs: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
means: the Gaussian mean Tensor.
log_scales: the Gaussian log stddev Tensor.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
scaled_input_range: the [min,max] intensity range of the input data after scaling.
"""
assert inputs.shape == means.shape
bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
original_input_range[1] - original_input_range[0]
)
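# each scaled intensity value corresponds to one discretization bin of width bin_width; the likelihood of a
# pixel is the Gaussian probability mass over its bin, i.e. the difference of CDFs evaluated at the bin edges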
centered_x = inputs - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + bin_width / 2)
cdf_plus = self._approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - bin_width / 2)
cdf_min = self._approx_standard_normal_cdf(min_in)
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
inputs < -0.999,
log_cdf_plus,
torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == inputs.shape
return log_probs


class LatentDiffusionInferer(DiffusionInferer):
"""
@@ -201,3 +364,59 @@ def sample(

else:
return image

@torch.no_grad()
def get_likelihood(
self,
inputs: torch.Tensor,
autoencoder_model: Callable[..., torch.Tensor],
diffusion_model: Callable[..., torch.Tensor],
scheduler: Optional[Callable[..., torch.Tensor]] = None,
save_intermediates: Optional[bool] = False,
conditioning: Optional[torch.Tensor] = None,
original_input_range: Optional[Tuple] = (0, 255),
scaled_input_range: Optional[Tuple] = (0, 1),
verbose: Optional[bool] = True,
resample_latent_likelihoods: Optional[bool] = False,
resample_interpolation_mode: Optional[str] = "bilinear",
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Computes the likelihoods of the latent representations of the input.

Args:
inputs: input images, NxCxHxW[xD]
autoencoder_model: first stage model.
diffusion_model: model to compute likelihood from
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
save_intermediates: save the intermediate spatial KL maps
conditioning: Conditioning for network input.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
scaled_input_range: the [min,max] intensity range of the input data after scaling.
verbose: if true, prints the progress bar of the sampling process.
resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
dimension as the input images.
resample_interpolation_mode: if using resample_latent_likelihoods, select the interpolation mode: 'nearest' or 'bilinear'.
"""

latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
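# the likelihood is computed over the scaled latent representation rather than in image space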
outputs = super().get_likelihood(
inputs=latents,
diffusion_model=diffusion_model,
scheduler=scheduler,
save_intermediates=save_intermediates,
conditioning=conditioning,
verbose=verbose,
)
if save_intermediates and resample_latent_likelihoods:
intermediates = outputs[1]
from torchvision.transforms import Resize

interpolation_modes = {"nearest": 0, "bilinear": 2}
if resample_interpolation_mode not in interpolation_modes.keys():
raise ValueError(
f"resample_interpolation mode should be either nearest or bilinear, not {resample_interpolation_mode}"
)
resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode])
intermediates = [resizer(x) for x in intermediates]
outputs = (outputs[0], intermediates)
return outputs
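
A minimal usage sketch of the new likelihood API (not part of this diff), mirroring the test added below; the import paths and the DiffusionModelUNet constructor arguments are assumptions and would need to match an actual configuration:

import torch

from generative.inferers import DiffusionInferer  # import paths are assumptions
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

# hypothetical UNet configuration; the argument names/values below are assumptions
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(8, 8),
    attention_levels=(False, True),
    num_res_blocks=1,
    num_head_channels=8,
    norm_num_groups=8,
)
model.eval()

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = DiffusionInferer(scheduler=scheduler)

images = torch.randn(2, 1, 16, 16)  # stand-in batch, NxCxHxW
likelihood, kl_maps = inferer.get_likelihood(
    inputs=images, diffusion_model=model, scheduler=scheduler, save_intermediates=True
)
# likelihood: per-image accumulated KL, shape (N,); kl_maps: one spatial KL map per timestep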
30 changes: 27 additions & 3 deletions generative/networks/schedulers/ddpm.py
@@ -90,7 +90,7 @@ def __init__(
self.clip_sample = clip_sample
self.variance_type = variance_type

# setable values
# settable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

@@ -109,9 +109,34 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
].copy()
self.timesteps = torch.from_numpy(timesteps).to(device)

def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
"""
Compute the mean of the posterior at timestep t.

Args:
timestep: current timestep.
x_0: the noise-free input.
x_t: the input noised to timestep t.

Returns:
The posterior mean.
"""
# these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
# (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
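# mu_tilde_t(x_t, x_0) = (sqrt(alpha_prod_t_prev) * beta_t / (1 - alpha_prod_t)) * x_0
#                        + (sqrt(alpha_t) * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * x_t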
alpha_t = self.alphas[timestep]
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one

x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)

mean = x_0_coefficient * x_0 + x_t_coefficient * x_t

return mean

def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the variance.
Compute the variance of the posterior at timestep t.

Args:
timestep: current timestep.
@@ -127,7 +152,6 @@ def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]

# hacks - were probably added for training stability
if self.variance_type == "fixed_small":
variance = torch.clamp(variance, min=1e-20)
31 changes: 31 additions & 0 deletions tests/test_diffusion_inferer.py
@@ -141,6 +141,37 @@ def test_sampler_conditioned(self, model_params, input_shape):
)
self.assertEqual(len(intermediates), 10)

@parameterized.expand(TEST_CASES)
def test_get_likelihood(self, model_params, input_shape):
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
input = torch.randn(input_shape).to(device)
scheduler = DDPMScheduler(
num_train_timesteps=10,
)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
likelihood, intermediates = inferer.get_likelihood(
inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
)
self.assertEqual(intermediates[0].shape, input.shape)
self.assertEqual(likelihood.shape[0], input.shape[0])

def test_normal_cdf(self):
from scipy.stats import norm

scheduler = DDPMScheduler(
num_train_timesteps=10,
)
inferer = DiffusionInferer(scheduler=scheduler)

x = torch.linspace(-10, 10, 20)
cdf_approx = inferer._approx_standard_normal_cdf(x)
cdf_true = norm.cdf(x)
torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)


if __name__ == "__main__":
unittest.main()