From f2f154e9674dff0aeeeceaff95e5bd71be6b7032 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 27 Nov 2023 16:14:57 +0000 Subject: [PATCH 1/5] Added normal and latent inferers for ControlNet. Added tests (copied from the normal inferer tests, but with the addition of controlnet support). --- generative/inferers/__init__.py | 3 +- generative/inferers/inferer.py | 635 ++++++++++++- generative/networks/nets/controlnet.py | 8 +- tests/test_controlnet_inferers.py | 1174 ++++++++++++++++++++++++ 4 files changed, 1813 insertions(+), 7 deletions(-) create mode 100644 tests/test_controlnet_inferers.py diff --git a/generative/inferers/__init__.py b/generative/inferers/__init__.py index e6402093..49c195b6 100644 --- a/generative/inferers/__init__.py +++ b/generative/inferers/__init__.py @@ -11,4 +11,5 @@ from __future__ import annotations -from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer +from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer, \ + ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 1d1c3a9c..196c9aea 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -13,7 +13,6 @@ import math from collections.abc import Callable, Sequence - import torch import torch.nn as nn import torch.nn.functional as F @@ -322,7 +321,6 @@ def _get_decoder_log_likelihood( assert log_probs.shape == inputs.shape return log_probs - class LatentDiffusionInferer(DiffusionInferer): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can @@ -571,6 +569,639 @@ def get_likelihood( outputs = (outputs[0], intermediates) return outputs +class ControlNetDiffusionInferer(Inferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: nn.Module) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet(x=noisy_image, + timesteps=timesteps, + controlnet_cond=cn_cond, + ) + if mode == "concat": + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, seg=seg, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), + controlnet_cond= cn_cond, + ) + # 2. predict noise model_output + if mode == "concat": + model_input = torch.cat([image, conditioning], dim=1) + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, seg=seg, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + + # 3. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), + controlnet_cond=cn_cond, + ) + + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, seg=seg, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, seg=seg, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t ** 0.5) * noisy_image - (beta_prod_t ** 0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample ยต_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh( + torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + assert inputs.shape == means.shape + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == inputs.shape + return log_probs + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: nn.Module, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + + def __call__( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + + if self.ldm_latent_shape is not None: + latent = self.ldm_resizer(latent) + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet = controlnet, + noise=noise, + timesteps=timesteps, + cn_cond = cn_cond, + condition=condition, + mode=mode, + seg=seg, + ) + else: + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet = controlnet, + noise=noise, + timesteps=timesteps, + cn_cond = cn_cond, + condition=condition, + mode=mode, + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. " + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet = controlnet, + cn_cond = cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + else: + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.ldm_latent_shape is not None: + latent = self.autoencoder_resizer(latent) + latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] + + image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + if isinstance(autoencoder_model, SPADEAutoencoderKL): + intermediates.append( + autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor), seg=seg + ) + else: + intermediates.append( + autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor) + ) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = self.ldm_resizer(latents) + + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + else: + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + class VQVAETransformerInferer(Inferer): """ diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index 4eb78802..e2e664d6 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -171,20 +171,20 @@ def __init__( super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "when using with_conditioning." ) if cross_attention_dim is not None and with_conditioning is False: raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + "ControlNet expects with_conditioning=True when specifying the cross_attention_dim." ) # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): - raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + raise ValueError("ControlNet expects all num_channels being multiple of norm_num_groups") if len(num_channels) != len(attention_levels): - raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + raise ValueError("ControlNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 00000000..6f14175b --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1174 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized +from generative.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from generative.networks.nets import DiffusionModelUNet, ControlNet, AutoencoderKL, \ + SPADEAutoencoderKL, SPADEDiffusionModelUNet, VQVAE +from generative.networks.schedulers import DDIMScheduler, DDPMScheduler + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8 + }, + { + "spatial_dims": 2, + "in_channels": 1, + "num_channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1 + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "num_channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1 + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "num_channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "num_channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "num_channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + +class CN_TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, + timesteps=timesteps, cn_cond=mask) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, + controlnet=controlnet, cn_cond=mask, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, + save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, + save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, + save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + def test_normal_cdf(self): + from scipy.stats import norm + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + +class LCN_TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps, + controlnet=controlnet, cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, + controlnet=controlnet, cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, + input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["num_channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, + controlnet=controlnet, cn_cond=mask, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + num_channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + num_channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + num_channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16,], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() From 10d646c93fdbea7ae8b270c8a6c7d6a87980685c Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 30 Nov 2023 15:24:40 +0000 Subject: [PATCH 2/5] Changed isinstance() ifs to partial to make it less messy on the SPADE check. Fixed some formatting typos. Deleted two functions and changed inheritance of ControlNetDiffusionInferer. Changed names of tests to agree with caps convention. --- generative/inferers/inferer.py | 265 ++++++++++-------------------- tests/test_controlnet_inferers.py | 4 +- 2 files changed, 85 insertions(+), 184 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 196c9aea..533234af 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -19,7 +19,7 @@ from monai.inferers import Inferer from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import optional_import - +from functools import partial from generative.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -30,7 +30,6 @@ class DiffusionInferer(Inferer): DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. - Args: scheduler: diffusion scheduler. """ @@ -69,10 +68,9 @@ def __call__( if mode == "concat": noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, seg=seg) - else: - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + diffusion_model = partial(diffusion_model, seg = seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) return prediction @@ -116,23 +114,15 @@ def sample( # 1. predict noise model_output if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg - ) - else: - model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None) else: - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, seg=seg - ) - else: - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning) # 2. compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -194,15 +184,14 @@ def get_likelihood( noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, seg=seg) - else: - model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) else: - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, seg=seg) - else: - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) @@ -443,29 +432,19 @@ def sample( "labels for each must be compatible. " ) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - outputs = super().sample( - input_noise=input_noise, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - intermediate_steps=intermediate_steps, - conditioning=conditioning, - mode=mode, - verbose=verbose, - seg=seg, - ) - else: - outputs = super().sample( - input_noise=input_noise, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - intermediate_steps=intermediate_steps, - conditioning=conditioning, - mode=mode, - verbose=verbose, - ) + sample = partial(super().sample, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + + outputs = sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose + ) if save_intermediates: latent, latent_intermediates = outputs @@ -541,27 +520,20 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - outputs = super().get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - mode=mode, - verbose=verbose, - seg=seg, - ) - else: - outputs = super().get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - mode=mode, - verbose=verbose, - ) + + get_likelihood = partial(super().get_likelihood, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().get_likelihood + + outputs = get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose + ) + if save_intermediates and resample_latent_likelihoods: intermediates = outputs[1] resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) @@ -569,10 +541,11 @@ def get_likelihood( outputs = (outputs[0], intermediates) return outputs -class ControlNetDiffusionInferer(Inferer): +class ControlNetDiffusionInferer(DiffusionInferer): """ ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + Args: scheduler: diffusion scheduler. """ @@ -621,16 +594,13 @@ def __call__( noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, seg=seg, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) - else: - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + + prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) return prediction @@ -683,31 +653,22 @@ def sample( # 2. predict noise model_output if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) - else: - model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) else: - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, seg=seg, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) - else: - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) # 3. compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -778,27 +739,19 @@ def get_likelihood( if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, seg=seg, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) - else: - model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) else: - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, seg=seg, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) - else: - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = partial(diffusion_model, seg=seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample + ) # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) @@ -841,7 +794,7 @@ def get_likelihood( if t == 0: # compute -log p(x_0|x_1) - kl = -self._get_decoder_log_likelihood( + kl = -super()._get_decoder_log_likelihood( inputs=inputs, means=predicted_mean, log_scales=0.5 * log_predicted_variance, @@ -866,58 +819,6 @@ def get_likelihood( else: return total_kl - def _approx_standard_normal_cdf(self, x): - """ - A fast approximation of the cumulative distribution function of the - standard normal. Code adapted from https://github.com/openai/improved-diffusion. - """ - - return 0.5 * ( - 1.0 + torch.tanh( - torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) - ) - - def _get_decoder_log_likelihood( - self, - inputs: torch.Tensor, - means: torch.Tensor, - log_scales: torch.Tensor, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - ) -> torch.Tensor: - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. Code adapted from https://github.com/openai/improved-diffusion. - - Args: - input: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - means: the Gaussian mean Tensor. - log_scales: the Gaussian log stddev Tensor. - original_input_range: the [min,max] intensity range of the input data before any scaling was applied. - scaled_input_range: the [min,max] intensity range of the input data after scaling. - """ - assert inputs.shape == means.shape - bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( - original_input_range[1] - original_input_range[0] - ) - centered_x = inputs - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + bin_width / 2) - cdf_plus = self._approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - bin_width / 2) - cdf_min = self._approx_standard_normal_cdf(min_in) - log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = torch.where( - inputs < -0.999, - log_cdf_plus, - torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), - ) - assert log_probs.shape == inputs.shape - return log_probs - class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): """ ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 6f14175b..50971125 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -430,7 +430,7 @@ ], ] -class CN_TestDiffusionSamplingInferer(unittest.TestCase): +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(CNDM_TEST_CASES) def test_call(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) @@ -606,7 +606,7 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input ) self.assertEqual(len(intermediates), 10) -class LCN_TestDiffusionSamplingInferer(unittest.TestCase): +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_prediction_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, From 83d6206888134bd5ef6095a6cd1d57d206d90a4e Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 30 Nov 2023 15:25:57 +0000 Subject: [PATCH 3/5] Changed isinstance() ifs to partial to make it less messy on the SPADE check. Fixed some formatting typos. Deleted two functions and changed inheritance of ControlNetDiffusionInferer. Changed names of tests to agree with caps convention. + run autofix --- generative/inferers/__init__.py | 9 +- generative/inferers/inferer.py | 265 +++++++++++++++---------- generative/networks/nets/controlnet.py | 4 +- tests/test_controlnet_inferers.py | 190 +++++++++++++----- 4 files changed, 305 insertions(+), 163 deletions(-) diff --git a/generative/inferers/__init__.py b/generative/inferers/__init__.py index 49c195b6..92bbe69f 100644 --- a/generative/inferers/__init__.py +++ b/generative/inferers/__init__.py @@ -11,5 +11,10 @@ from __future__ import annotations -from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer, \ - ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, + LatentDiffusionInferer, + VQVAETransformerInferer, +) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 533234af..3cad8a3d 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -13,13 +13,15 @@ import math from collections.abc import Callable, Sequence +from functools import partial + import torch import torch.nn as nn import torch.nn.functional as F from monai.inferers import Inferer from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import optional_import -from functools import partial + from generative.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -68,8 +70,11 @@ def __call__( if mode == "concat": noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None - diffusion_model = partial(diffusion_model, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) return prediction @@ -114,15 +119,23 @@ def sample( # 1. predict noise model_output if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None) + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) else: - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning) + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) # 2. compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -184,12 +197,18 @@ def get_likelihood( noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) else: - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted @@ -310,6 +329,7 @@ def _get_decoder_log_likelihood( assert log_probs.shape == inputs.shape return log_probs + class LatentDiffusionInferer(DiffusionInferer): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can @@ -432,8 +452,9 @@ def sample( "labels for each must be compatible. " ) - sample = partial(super().sample, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + sample = ( + partial(super().sample, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + ) outputs = sample( input_noise=input_noise, @@ -443,7 +464,7 @@ def sample( intermediate_steps=intermediate_steps, conditioning=conditioning, mode=mode, - verbose=verbose + verbose=verbose, ) if save_intermediates: @@ -520,9 +541,11 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - - get_likelihood = partial(super().get_likelihood, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().get_likelihood + get_likelihood = ( + partial(super().get_likelihood, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else super().get_likelihood + ) outputs = get_likelihood( inputs=latents, @@ -531,7 +554,7 @@ def get_likelihood( save_intermediates=save_intermediates, conditioning=conditioning, mode=mode, - verbose=verbose + verbose=verbose, ) if save_intermediates and resample_latent_likelihoods: @@ -541,32 +564,32 @@ def get_likelihood( outputs = (outputs[0], intermediates) return outputs + class ControlNetDiffusionInferer(DiffusionInferer): """ - ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal - forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. - Args: - scheduler: diffusion scheduler. - """ + Args: + scheduler: diffusion scheduler. + """ def __init__(self, scheduler: nn.Module) -> None: Inferer.__init__(self) self.scheduler = scheduler def __call__( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - cn_cond: torch.Tensor, - condition: torch.Tensor | None = None, - mode: str = "crossattn", - seg: torch.Tensor | None = None, + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, ) -> torch.Tensor: - """ Implements the forward pass for a supervised training iteration. @@ -586,38 +609,43 @@ def __call__( raise NotImplementedError(f"{mode} condition is not supported") noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - down_block_res_samples, mid_block_res_sample = controlnet(x=noisy_image, - timesteps=timesteps, - controlnet_cond=cn_cond, - ) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) if mode == "concat": noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + prediction = diffusion_model( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) return prediction @torch.no_grad() def sample( - self, - input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], - cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - verbose: bool = True, - seg: torch.Tensor | None = None, + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -647,27 +675,36 @@ def sample( for t in progress_bar: # 1. ControlNet forward down_block_res_samples, mid_block_res_sample = controlnet( - x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), - controlnet_cond= cn_cond, + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond ) # 2. predict noise model_output if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model( - model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample + mid_block_additional_residual=mid_block_res_sample, ) else: - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample + mid_block_additional_residual=mid_block_res_sample, ) # 3. compute previous image: x_t -> x_t-1 @@ -681,19 +718,19 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], - cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, - seg: torch.Tensor | None = None, + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods for an input. @@ -733,25 +770,36 @@ def get_likelihood( timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) down_block_res_samples, mid_block_res_sample = controlnet( - x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), - controlnet_cond=cn_cond, + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond ) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model - model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + model_output = diffusion_model( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) else: - diffusion_model = partial(diffusion_model, seg=seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample - ) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + model_output = diffusion_model( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) @@ -771,7 +819,7 @@ def get_likelihood( elif scheduler.prediction_type == "sample": pred_original_sample = model_output elif scheduler.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t ** 0.5) * noisy_image - (beta_prod_t ** 0.5) * model_output + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" if scheduler.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) @@ -804,11 +852,11 @@ def get_likelihood( else: # compute kl between two normals kl = 0.5 * ( - -1.0 - + log_predicted_variance - - log_posterior_variance - + torch.exp(log_posterior_variance - log_predicted_variance) - + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: @@ -819,6 +867,7 @@ def get_likelihood( else: return total_kl + class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): """ ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, @@ -882,21 +931,19 @@ def __call__( with torch.no_grad(): latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor - if self.ldm_latent_shape is not None: latent = self.ldm_resizer(latent) if cn_cond.shape[2:] != latent.shape[2:]: cn_cond = F.interpolate(cn_cond, latent.shape[2:]) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): prediction = super().__call__( inputs=latent, diffusion_model=diffusion_model, - controlnet = controlnet, + controlnet=controlnet, noise=noise, timesteps=timesteps, - cn_cond = cn_cond, + cn_cond=cn_cond, condition=condition, mode=mode, seg=seg, @@ -905,10 +952,10 @@ def __call__( prediction = super().__call__( inputs=latent, diffusion_model=diffusion_model, - controlnet = controlnet, + controlnet=controlnet, noise=noise, timesteps=timesteps, - cn_cond = cn_cond, + cn_cond=cn_cond, condition=condition, mode=mode, ) @@ -965,8 +1012,8 @@ def sample( outputs = super().sample( input_noise=input_noise, diffusion_model=diffusion_model, - controlnet = controlnet, - cn_cond = cn_cond, + controlnet=controlnet, + cn_cond=cn_cond, scheduler=scheduler, save_intermediates=save_intermediates, intermediate_steps=intermediate_steps, diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index e2e664d6..ebe2459c 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -175,9 +175,7 @@ def __init__( "when using with_conditioning." ) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "ControlNet expects with_conditioning=True when specifying the cross_attention_dim." - ) + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 50971125..c38eb4c8 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -15,9 +15,16 @@ import torch from parameterized import parameterized + from generative.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer -from generative.networks.nets import DiffusionModelUNet, ControlNet, AutoencoderKL, \ - SPADEAutoencoderKL, SPADEDiffusionModelUNet, VQVAE +from generative.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) from generative.networks.schedulers import DDIMScheduler, DDPMScheduler CNDM_TEST_CASES = [ @@ -30,7 +37,7 @@ "norm_num_groups": 8, "attention_levels": [True], "num_res_blocks": 1, - "num_head_channels": 8 + "num_head_channels": 8, }, { "spatial_dims": 2, @@ -40,8 +47,8 @@ "norm_num_groups": 8, "num_res_blocks": 1, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], - "conditioning_embedding_in_channels": 1 + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, }, (2, 1, 8, 8), ], @@ -64,7 +71,7 @@ "num_res_blocks": 1, "norm_num_groups": 8, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (2, 1, 8, 8, 8), @@ -104,8 +111,8 @@ "num_res_blocks": 1, "norm_num_groups": 4, "num_head_channels": 4, - "conditioning_embedding_num_channels": [16,], - "conditioning_embedding_in_channels": 1 + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, }, (1, 1, 8, 8), (1, 3, 4, 4), @@ -143,7 +150,7 @@ "num_res_blocks": 1, "norm_num_groups": 8, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 16, 16), @@ -182,7 +189,7 @@ "num_res_blocks": 1, "norm_num_groups": 8, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 16, 16, 16), @@ -223,7 +230,7 @@ "num_res_blocks": 1, "norm_num_groups": 4, "num_head_channels": 4, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 12, 12), @@ -262,7 +269,7 @@ "num_res_blocks": 1, "norm_num_groups": 8, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 12, 12), @@ -301,7 +308,7 @@ "num_res_blocks": 1, "norm_num_groups": 8, "num_head_channels": 8, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 12, 12, 12), @@ -341,7 +348,7 @@ "num_res_blocks": 1, "norm_num_groups": 4, "num_head_channels": 4, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 8, 8), @@ -381,7 +388,7 @@ "num_res_blocks": 1, "norm_num_groups": 4, "num_head_channels": 4, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 8, 8), @@ -422,7 +429,7 @@ "num_res_blocks": 1, "norm_num_groups": 4, "num_head_channels": 4, - "conditioning_embedding_num_channels": [16,], + "conditioning_embedding_num_channels": [16], "conditioning_embedding_in_channels": 1, }, (1, 1, 8, 8), @@ -430,6 +437,7 @@ ], ] + class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(CNDM_TEST_CASES) def test_call(self, model_params, controlnet_params, input_shape): @@ -447,8 +455,9 @@ def test_call(self, model_params, controlnet_params, input_shape): inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - sample = inferer(inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, - timesteps=timesteps, cn_cond=mask) + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) self.assertEqual(sample.shape, input_shape) @parameterized.expand(CNDM_TEST_CASES) @@ -466,8 +475,13 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) sample, intermediates = inferer.sample( - input_noise=noise, diffusion_model=model, scheduler=scheduler, - controlnet=controlnet, cn_cond=mask, save_intermediates=True, intermediate_steps=1 + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, ) self.assertEqual(len(intermediates), 10) @@ -486,8 +500,13 @@ def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) sample, intermediates = inferer.sample( - input_noise=noise, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, - save_intermediates=True, intermediate_steps=1 + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, ) self.assertEqual(len(intermediates), 10) @@ -506,8 +525,13 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) sample, intermediates = inferer.sample( - input_noise=noise, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, - save_intermediates=True, intermediate_steps=1 + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, ) self.assertEqual(len(intermediates), 10) @@ -555,14 +579,19 @@ def test_get_likelihood(self, model_params, controlnet_params, input_shape): inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) likelihood, intermediates = inferer.get_likelihood( - inputs=input, diffusion_model=model, scheduler=scheduler, controlnet=controlnet, cn_cond=mask, - save_intermediates=True + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, ) self.assertEqual(intermediates[0].shape, input.shape) self.assertEqual(likelihood.shape[0], input.shape[0]) def test_normal_cdf(self): from scipy.stats import norm + scheduler = DDPMScheduler(num_train_timesteps=10) inferer = ControlNetDiffusionInferer(scheduler=scheduler) x = torch.linspace(-10, 10, 20) @@ -606,11 +635,18 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input ) self.assertEqual(len(intermediates), 10) + class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_prediction_shape( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -657,15 +693,26 @@ def test_prediction_shape( ) else: prediction = inferer( - inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps, - controlnet=controlnet, cn_cond=mask, + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, ) self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_sample_shape( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -709,15 +756,25 @@ def test_sample_shape( ) else: sample = inferer.sample( - input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, - controlnet=controlnet, cn_cond=mask, + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, ) self.assertEqual(sample.shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_sample_intermediates( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -777,8 +834,14 @@ def test_sample_intermediates( @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_get_likelihoods( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -838,8 +901,14 @@ def test_get_likelihoods( @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_resample_likelihoods( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -901,8 +970,14 @@ def test_resample_likelihoods( @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_prediction_shape_conditioned_concat( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -975,8 +1050,14 @@ def test_prediction_shape_conditioned_concat( @parameterized.expand(LATENT_CNDM_TEST_CASES) def test_sample_shape_conditioned_concat( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -1044,8 +1125,14 @@ def test_sample_shape_conditioned_concat( @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) def test_sample_shape_different_latents( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, controlnet_params, - input_shape, latent_shape + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, ): if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) @@ -1102,8 +1189,13 @@ def test_sample_shape_different_latents( ) else: prediction = inferer( - inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, - controlnet=controlnet, cn_cond=mask, timesteps=timesteps + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, ) self.assertEqual(prediction.shape, latent_shape) @@ -1140,7 +1232,7 @@ def test_incompatible_spade_setup(self): attention_levels=[False, False], num_res_blocks=1, num_head_channels=4, - conditioning_embedding_num_channels=[16,], + conditioning_embedding_num_channels=[16], ) device = "cuda:0" if torch.cuda.is_available() else "cpu" From 4fe2ef95f30de0d153d3c11f6745d537fa248037 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 30 Nov 2023 15:51:41 +0000 Subject: [PATCH 4/5] Taken partial call outside concatenation mode check. Added a few partials that had been forgotten last commit. --- generative/inferers/inferer.py | 171 ++++++++++----------------------- 1 file changed, 53 insertions(+), 118 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 3cad8a3d..f83e2eec 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -117,22 +117,17 @@ def sample( intermediates = [] for t in progress_bar: # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model( model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None ) else: - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model( image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning ) @@ -195,20 +190,15 @@ def get_likelihood( for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) else: - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted @@ -391,26 +381,16 @@ def __call__( if self.ldm_latent_shape is not None: latent = self.ldm_resizer(latent) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - prediction = super().__call__( - inputs=latent, - diffusion_model=diffusion_model, - noise=noise, - timesteps=timesteps, - condition=condition, - mode=mode, - seg=seg, - ) - else: - prediction = super().__call__( - inputs=latent, - diffusion_model=diffusion_model, - noise=noise, - timesteps=timesteps, - condition=condition, - mode=mode, - ) - + call = partial(super().__call__, seg = seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__ + prediction = call( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode + ) return prediction @torch.no_grad() @@ -483,7 +463,7 @@ def sample( for latent_intermediate in latent_intermediates: if isinstance(autoencoder_model, SPADEAutoencoderKL): intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor), seg=seg + autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor, seg=seg) ) else: intermediates.append( @@ -678,14 +658,13 @@ def sample( x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond ) # 2. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) - model_output = diffusion_model( model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), @@ -694,11 +673,6 @@ def sample( mid_block_additional_residual=mid_block_res_sample, ) else: - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model( image, timesteps=torch.Tensor((t,)).to(input_noise.device), @@ -773,13 +747,13 @@ def get_likelihood( x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond ) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model( noisy_image, timesteps=timesteps, @@ -788,11 +762,6 @@ def get_likelihood( mid_block_additional_residual=mid_block_res_sample, ) else: - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) model_output = diffusion_model( x=noisy_image, timesteps=timesteps, @@ -936,8 +905,10 @@ def __call__( if cn_cond.shape[2:] != latent.shape[2:]: cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + call = partial(super().__call__, seg = seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__ if isinstance(diffusion_model, SPADEDiffusionModelUNet): - prediction = super().__call__( + prediction = call( inputs=latent, diffusion_model=diffusion_model, controlnet=controlnet, @@ -945,19 +916,7 @@ def __call__( timesteps=timesteps, cn_cond=cn_cond, condition=condition, - mode=mode, - seg=seg, - ) - else: - prediction = super().__call__( - inputs=latent, - diffusion_model=diffusion_model, - controlnet=controlnet, - noise=noise, - timesteps=timesteps, - cn_cond=cn_cond, - condition=condition, - mode=mode, + mode=mode ) return prediction @@ -1008,22 +967,9 @@ def sample( if cn_cond.shape[2:] != input_noise.shape[2:]: cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - outputs = super().sample( - input_noise=input_noise, - diffusion_model=diffusion_model, - controlnet=controlnet, - cn_cond=cn_cond, - scheduler=scheduler, - save_intermediates=save_intermediates, - intermediate_steps=intermediate_steps, - conditioning=conditioning, - mode=mode, - verbose=verbose, - seg=seg, - ) - else: - outputs = super().sample( + sample = partial(super().sample, seg = seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + outputs = sample( input_noise=input_noise, diffusion_model=diffusion_model, controlnet=controlnet, @@ -1034,7 +980,7 @@ def sample( conditioning=conditioning, mode=mode, verbose=verbose, - ) + ) if save_intermediates: latent, latent_intermediates = outputs @@ -1118,31 +1064,20 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - outputs = super().get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - controlnet=controlnet, - cn_cond=cn_cond, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - mode=mode, - verbose=verbose, - seg=seg, - ) - else: - outputs = super().get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - controlnet=controlnet, - cn_cond=cn_cond, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - mode=mode, - verbose=verbose, - ) + get_likelihood = partial(super().get_likelihood, seg = seg) if \ + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().get_likelihood + outputs = get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + if save_intermediates and resample_latent_likelihoods: intermediates = outputs[1] resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) From bb1de766a91f70546fedbc605e0e2630067274d7 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Fri, 1 Dec 2023 07:50:13 +0000 Subject: [PATCH 5/5] There were two bugs (some functions wrongly named). Tests re-run and work. --- generative/inferers/inferer.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index f83e2eec..261f745b 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -907,17 +907,16 @@ def __call__( call = partial(super().__call__, seg = seg) if \ isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__ - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - prediction = call( - inputs=latent, - diffusion_model=diffusion_model, - controlnet=controlnet, - noise=noise, - timesteps=timesteps, - cn_cond=cn_cond, - condition=condition, - mode=mode - ) + prediction = call( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode + ) return prediction @@ -968,7 +967,8 @@ def sample( cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) sample = partial(super().sample, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model + isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + outputs = sample( input_noise=input_noise, diffusion_model=diffusion_model,