From fd29603cfdaf3e46f3d9313be211ed31309f5a25 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 20 Dec 2023 16:18:27 +0000 Subject: [PATCH 01/37] Adds inferers. Changes arg in the spade_diffusion_model from num_channels to channels to remain consistent with other networks added Signed-off-by: Mark Graham --- monai/inferers/__init__.py | 4 + monai/inferers/inferer.py | 1279 +++++++++++++++++ .../nets/spade_diffusion_model_unet.py | 40 +- test_spade_diffusion_model_unet.py | 66 +- tests/test_controlnet_inferers.py | 1266 ++++++++++++++++ tests/test_diffusion_inferer.py | 222 +++ tests/test_latent_diffusion_inferer.py | 765 ++++++++++ 7 files changed, 3589 insertions(+), 53 deletions(-) create mode 100644 tests/test_controlnet_inferers.py create mode 100644 tests/test_diffusion_inferer.py create mode 100644 tests/test_latent_diffusion_inferer.py diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 960380bfb8..8141b3111e 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -12,7 +12,11 @@ from __future__ import annotations from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, Inferer, + LatentDiffusionInferer, PatchInferer, SaliencyInferer, SimpleInferer, diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0b4199938d..35f6465dfd 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,24 +11,32 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import partial from pydoc import locate from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from monai.apps.utils import get_logger +from monai.data import decollate_batch from monai.data.meta_tensor import MetaTensor from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference +from monai.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + logger = get_logger(__name__) __all__ = [ @@ -752,3 +760,1274 @@ def network_wrapper( return out return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: nn.Module) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. 
+ condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat": + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) + + # 2. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. 
+ + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. 
Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + assert inputs.shape == means.shape + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == inputs.shape + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. 
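A minimal usage sketch for the DiffusionInferer defined above (illustrative only, not part of the patch): it assumes the DiffusionModelUNet and DDPMScheduler ported in companion commits of this series, so the import paths and constructor arguments below are assumptions rather than verified against this exact commit.

import torch
import torch.nn.functional as F
from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet   # assumed available alongside this patch
from monai.networks.schedulers import DDPMScheduler  # assumed import path

model = DiffusionModelUNet(
    spatial_dims=2, in_channels=1, out_channels=1,
    channels=(8, 8, 8), attention_levels=(False, False, True),
    num_head_channels=8, norm_num_groups=8,
)
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

# training step: the inferer noises the batch and asks the model to predict the noise
images = torch.randn(2, 1, 16, 16)
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()
prediction = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
loss = F.mse_loss(prediction, noise)

# sampling: iterate the scheduler from pure noise back to an image
scheduler.set_timesteps(num_inference_steps=1000)
samples = inferer.sample(input_noise=torch.randn(1, 1, 16, 16), diffusion_model=model, verbose=False)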
+ + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: nn.Module, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + call = super().__call__ + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + call = partial(super().__call__, seg=seg) + + prediction = call( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + ) + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. 
+ save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. " + ) + + sample = super().sample + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + sample = partial(super().sample, seg=seg) + + outputs = sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. 
+ verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + get_likelihood = super().get_likelihood + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + get_likelihood = partial(super().get_likelihood, seg=seg) + + outputs = get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class ControlNetDiffusionInferer(DiffusionInferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: nn.Module) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. 
+ seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) + if mode == "concat": + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + prediction = diffuse( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + ) + # 2. predict noise model_output + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat": + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # 3. 
compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + ) + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + 
predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -super()._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. 
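For reference, the per-timestep quantity accumulated in get_likelihood above is the KL divergence between the DDPM forward posterior q(x_{t-1} | x_t, x_0) and the model's reverse Gaussian p_theta(x_{t-1} | x_t); the two coefficients computed in step 4 are those of formula (7) in Ho et al. 2020, and at t = 0 the discretized decoder log-likelihood is used instead:

\tilde{\mu}_t(x_t, \hat{x}_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,\hat{x}_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t,
\qquad
D_{\mathrm{KL}}\!\left(q \,\|\, p_\theta\right)
  = \frac{1}{2}\left[\log\frac{\sigma_\theta^2}{\tilde{\sigma}_t^2}
  + \frac{\tilde{\sigma}_t^2}{\sigma_\theta^2}
  + \frac{(\tilde{\mu}_t-\mu_\theta)^2}{\sigma_\theta^2} - 1\right].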
+ """ + + def __init__( + self, + scheduler: nn.Module, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + + def __call__( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + call = super().__call__ + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + call = partial(super().__call__, seg=seg) + + prediction = call( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode, + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. 
+ save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. " + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + sample = super().sample + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + sample = partial(super().sample, seg=seg) + + outputs = sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + controlnet: Callable[..., torch.Tensor], + cn_cond: torch.Tensor, + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. 
If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + get_likelihood = super().get_likelihood + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + get_likelihood = partial(super().get_likelihood, seg=seg) + + outputs = get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(Inferer): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. 
+ # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item() + else: + start = 0 + prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + if return_latent: + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], + starting_tokens: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + 
resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index d53327100e..bffc9c5465 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -618,7 +618,7 @@ class SPADEDiffusionModelUNet(nn.Module): out_channels: number of output 
channels. label_nc: number of semantic channels for SPADE normalisation. num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. + channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. @@ -641,7 +641,7 @@ def __init__( out_channels: int, label_nc: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, @@ -667,10 +667,10 @@ def __init__( ) # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): @@ -683,9 +683,9 @@ def __init__( ) if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." @@ -700,7 +700,7 @@ def __init__( ) self.in_channels = in_channels - self.block_out_channels = num_channels + self.block_out_channels = channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels @@ -712,7 +712,7 @@ def __init__( self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -720,9 +720,9 @@ def __init__( ) # time - time_embed_dim = num_channels[0] * 4 + time_embed_dim = channels[0] * 4 self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding @@ -732,11 +732,11 @@ def __init__( # down self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): + output_channel = channels[0] + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, @@ -762,7 +762,7 @@ def __init__( # mid self.middle_block = get_mid_block( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, norm_eps=norm_eps, @@ -776,7 +776,7 @@ def __init__( # up self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) + reversed_block_out_channels = list(reversed(channels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = 
list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) @@ -784,9 +784,9 @@ def __init__( for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] - is_final_block = i == len(num_channels) - 1 + is_final_block = i == len(channels) - 1 up_block = get_spade_up_block( spatial_dims=spatial_dims, @@ -814,12 +814,12 @@ def __init__( # out self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), nn.SiLU(), zero_module( Convolution( spatial_dims=spatial_dims, - in_channels=num_channels[0], + in_channels=channels[0], out_channels=out_channels, strides=1, kernel_size=3, diff --git a/test_spade_diffusion_model_unet.py b/test_spade_diffusion_model_unet.py index c8a2103cf6..113e58ed89 100644 --- a/test_spade_diffusion_model_unet.py +++ b/test_spade_diffusion_model_unet.py @@ -26,7 +26,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -38,7 +38,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": (1, 1, 2), - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -50,7 +50,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -63,7 +63,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -76,7 +76,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -90,7 +90,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -103,7 +103,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, True, True), "num_head_channels": (0, 2, 4), "norm_num_groups": 8, @@ -119,7 +119,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -132,7 +132,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -144,7 +144,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -157,7 +157,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, 
True), "num_head_channels": 8, "norm_num_groups": 8, @@ -170,7 +170,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -184,7 +184,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -197,7 +197,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": (0, 0, 4), "norm_num_groups": 8, @@ -213,7 +213,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -229,7 +229,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -246,7 +246,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -279,7 +279,7 @@ def test_timestep_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -296,7 +296,7 @@ def test_label_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -313,7 +313,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -331,7 +331,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 12), + channels=(8, 8, 12), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -344,13 +344,13 @@ def test_attention_levels_with_different_length_num_head_channels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), num_head_channels=(0, 2), norm_num_groups=8, ) - def test_num_res_blocks_with_different_length_num_channels(self): + def test_num_res_blocks_with_different_length_channels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -358,7 +358,7 @@ def test_num_res_blocks_with_different_length_num_channels(self): in_channels=1, out_channels=1, num_res_blocks=(1, 1), - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -370,7 +370,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -395,7 +395,7 @@ def test_with_conditioning_cross_attention_dim_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, 
transformer_num_layers=1, @@ -410,7 +410,7 @@ def test_context_with_conditioning_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=False, transformer_num_layers=1, @@ -433,7 +433,7 @@ def test_shape_conditioned_models_class_conditioning(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -455,7 +455,7 @@ def test_conditioned_models_no_class_labels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -469,7 +469,7 @@ def test_conditioned_models_no_class_labels(self): seg=torch.rand((1, 3, 16, 32)), ) - def test_model_num_channels_not_same_size_of_attention_levels(self): + def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -477,7 +477,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False), norm_num_groups=8, num_head_channels=8, @@ -518,7 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=4, ) @@ -537,7 +537,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(16, 16, 16), + channels=(16, 16, 16), attention_levels=(False, False, True), norm_num_groups=16, with_conditioning=True, diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 0000000000..8df61b6cde --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1266 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + 
"out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + 
"norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = 
torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + 
controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, 
(input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_intermediates( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = 
DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_get_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_resample_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == 
"SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + 
autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + 
controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py new file mode 100644 index 0000000000..03a7badc20 --- /dev/null +++ b/tests/test_diffusion_inferer.py @@ -0,0 +1,222 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = 
DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + 
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py new file mode 100644 index 0000000000..402af17c7a --- /dev/null +++ b/tests/test_latent_diffusion_inferer.py @@ -0,0 +1,765 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInferer +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 
1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + 
autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, 
input_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + 
@parameterized.expand(TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + 
input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + 
scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() From 495758f01b8b95a25ea969c12e91c7818e6d7640 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 20 Dec 2023 16:24:05 +0000 Subject: [PATCH 02/37] Updates docs --- docs/source/inferers.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 33f9e14d83..326f56e96c 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -49,6 +49,29 @@ Inferers :members: :special-members: __call__ +`DiffusionInferer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionInferer + :members: + :special-members: __call__ + +`LatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LatentDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetLatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetLatentDiffusionInferer + :members: + :special-members: __call__ Splitters --------- From 3ebaf9f5c01f105f923903aa18e8b7b357637829 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 21 Dec 2023 10:22:19 +0000 Subject: [PATCH 03/37] Start to address mypy issues, inc changing base class from Inferers --- monai/inferers/inferer.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 35f6465dfd..6a09334bb5 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -30,7 +30,9 @@ from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet + +# from monai.networks.schedulers import Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp @@ -762,7 +764,7 @@ def network_wrapper( return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) -class DiffusionInferer(Inferer): +class DiffusionInferer(nn.Module): """ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. 
@@ -778,7 +780,7 @@ def __init__(self, scheduler: nn.Module) -> None: def __call__( self, inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], + diffusion_model: DiffusionModelUNet, noise: torch.Tensor, timesteps: torch.Tensor, condition: torch.Tensor | None = None, @@ -801,16 +803,19 @@ def __call__( if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) if mode == "concat": - noisy_image = torch.cat([noisy_image, condition], dim=1) - condition = None + if condition is None: + raise ValueError("Conditioning is required for concat condition") + else: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None diffusion_model = ( partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model ) - prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) return prediction @@ -818,8 +823,8 @@ def __call__( def sample( self, input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, @@ -880,7 +885,7 @@ def sample( def get_likelihood( self, inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], + diffusion_model: DiffusionModelUNet, scheduler: Callable[..., torch.Tensor] | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, @@ -1281,7 +1286,7 @@ def get_likelihood( return outputs -class ControlNetDiffusionInferer(DiffusionInferer): +class ControlNetDiffusionInferer(nn.Module): """ ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. @@ -1826,7 +1831,7 @@ def get_likelihood( return outputs -class VQVAETransformerInferer(Inferer): +class VQVAETransformerInferer(nn.Module): """ Class to perform inference with a VQVAE + Transformer model. 
""" From 8fef41c9d423002fb27c8365813109f7ca55ae0d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 4 Jan 2024 15:52:06 +0000 Subject: [PATCH 04/37] Address more mypy Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 106 ++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 51 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 6a09334bb5..137aba5eaa 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -30,8 +30,8 @@ from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet - +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet, ControlNet, DecoderOnlyTransformer +from monai.networks.schedulers import Scheduler # from monai.networks.schedulers import Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import @@ -773,8 +773,9 @@ class DiffusionInferer(nn.Module): scheduler: diffusion scheduler. """ - def __init__(self, scheduler: nn.Module) -> None: - Inferer.__init__(self) + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self.scheduler = scheduler def __call__( @@ -846,7 +847,8 @@ def sample( """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") - + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") if not scheduler: scheduler = self.scheduler image = input_noise @@ -862,7 +864,7 @@ def sample( if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model ) - if mode == "concat": + if mode == "concat" and conditioning is not None: model_input = torch.cat([image, conditioning], dim=1) model_output = diffusion_model( model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None @@ -886,12 +888,12 @@ def get_likelihood( self, inputs: torch.Tensor, diffusion_model: DiffusionModelUNet, - scheduler: Callable[..., torch.Tensor] | None = None, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), verbose: bool = True, seg: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: @@ -920,6 +922,8 @@ def get_likelihood( ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") if verbose and has_tqdm: progress_bar = tqdm(scheduler.timesteps) else: @@ -935,7 +939,7 @@ def get_likelihood( if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model ) - if mode == "concat": + if mode == "concat" and conditioning is not None: noisy_image = torch.cat([noisy_image, conditioning], dim=1) model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) else: @@ -999,7 +1003,7 @@ def get_likelihood( + torch.exp(log_posterior_variance - 
log_predicted_variance) + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) - total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) if save_intermediates: intermediates.append(kl.cpu()) @@ -1023,8 +1027,8 @@ def _get_decoder_log_likelihood( inputs: torch.Tensor, means: torch.Tensor, log_scales: torch.Tensor, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), ) -> torch.Tensor: """ Compute the log-likelihood of a Gaussian distribution discretizing to a @@ -1076,7 +1080,7 @@ class LatentDiffusionInferer(DiffusionInferer): def __init__( self, - scheduler: nn.Module, + scheduler: Scheduler, scale_factor: float = 1.0, ldm_latent_shape: list | None = None, autoencoder_latent_shape: list | None = None, @@ -1084,18 +1088,18 @@ def __init__( super().__init__(scheduler=scheduler) self.scale_factor = scale_factor if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): - raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.") self.ldm_latent_shape = ldm_latent_shape self.autoencoder_latent_shape = autoencoder_latent_shape - if self.ldm_latent_shape is not None: + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) def __call__( self, inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, noise: torch.Tensor, timesteps: torch.Tensor, condition: torch.Tensor | None = None, @@ -1125,7 +1129,7 @@ def __call__( if isinstance(diffusion_model, SPADEDiffusionModelUNet): call = partial(super().__call__, seg=seg) - prediction = call( + prediction : torch.Tensor = call( inputs=latent, diffusion_model=diffusion_model, noise=noise, @@ -1139,9 +1143,9 @@ def __call__( def sample( self, input_noise: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, @@ -1221,9 +1225,9 @@ def sample( def get_likelihood( self, inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, mode: str = "crossattn", @@ -1295,15 +1299,15 @@ class ControlNetDiffusionInferer(nn.Module): scheduler: diffusion scheduler. 
""" - def __init__(self, scheduler: nn.Module) -> None: + def __init__(self, scheduler: Scheduler) -> None: Inferer.__init__(self) self.scheduler = scheduler def __call__( self, inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, noise: torch.Tensor, timesteps: torch.Tensor, cn_cond: torch.Tensor, @@ -1355,10 +1359,10 @@ def __call__( def sample( self, input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, @@ -1432,10 +1436,10 @@ def sample( def get_likelihood( self, inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, mode: str = "crossattn", @@ -1591,7 +1595,7 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): def __init__( self, - scheduler: nn.Module, + scheduler: Scheduler, scale_factor: float = 1.0, ldm_latent_shape: list | None = None, autoencoder_latent_shape: list | None = None, @@ -1609,9 +1613,9 @@ def __init__( def __call__( self, inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, noise: torch.Tensor, timesteps: torch.Tensor, cn_cond: torch.Tensor, @@ -1664,9 +1668,9 @@ def __call__( def sample( self, input_noise: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, cn_cond: torch.Tensor, scheduler: Callable[..., torch.Tensor] | None = None, save_intermediates: bool | None = False, @@ -1756,11 +1760,11 @@ def sample( def get_likelihood( self, inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - controlnet: Callable[..., torch.Tensor], + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, mode: str = "crossattn", @@ -1842,8 +1846,8 @@ def __init__(self) -> None: def __call__( self, inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, ordering: Callable[..., torch.Tensor], condition: torch.Tensor | None = None, return_latent: bool = False, @@ -1893,8 +1897,8 @@ def sample( self, latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], starting_tokens: torch.Tensor, - vqvae_model: 
Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, ordering: Callable[..., torch.Tensor], conditioning: torch.Tensor | None = None, temperature: float = 1.0, @@ -1956,8 +1960,8 @@ def sample( def get_likelihood( self, inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, ordering: Callable[..., torch.Tensor], condition: torch.Tensor | None = None, resample_latent_likelihoods: bool = False, From 9f5a903b1101e9a34e08c1f9afcec238f6674db0 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 13:53:57 +0000 Subject: [PATCH 05/37] Inferers mypy compatible Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 71 ++++++++++++++++++++++----------------- monai/utils/__init__.py | 1 + 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 137aba5eaa..897e252cd0 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -30,11 +30,20 @@ from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet, ControlNet, DecoderOnlyTransformer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DecoderOnlyTransformer, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) from monai.networks.schedulers import Scheduler + # from monai.networks.schedulers import Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad -from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import +from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -764,7 +773,7 @@ def network_wrapper( return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) -class DiffusionInferer(nn.Module): +class DiffusionInferer(Inferer): """ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. @@ -773,12 +782,12 @@ class DiffusionInferer(nn.Module): scheduler: diffusion scheduler. 
""" - def __init__(self, scheduler: Scheduler) -> None: + def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] super().__init__() self.scheduler = scheduler - def __call__( + def __call__( # type: ignore[override] self, inputs: torch.Tensor, diffusion_model: DiffusionModelUNet, @@ -1095,7 +1104,7 @@ def __init__( self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) - def __call__( + def __call__( # type: ignore[override] self, inputs: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, @@ -1129,7 +1138,7 @@ def __call__( if isinstance(diffusion_model, SPADEDiffusionModelUNet): call = partial(super().__call__, seg=seg) - prediction : torch.Tensor = call( + prediction: torch.Tensor = call( inputs=latent, diffusion_model=diffusion_model, noise=noise, @@ -1140,7 +1149,7 @@ def __call__( return prediction @torch.no_grad() - def sample( + def sample( # type: ignore[override] self, input_noise: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, @@ -1222,7 +1231,7 @@ def sample( return image @torch.no_grad() - def get_likelihood( + def get_likelihood( # type: ignore[override] self, inputs: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, @@ -1290,7 +1299,7 @@ def get_likelihood( return outputs -class ControlNetDiffusionInferer(nn.Module): +class ControlNetDiffusionInferer(DiffusionInferer): """ ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. @@ -1303,7 +1312,7 @@ def __init__(self, scheduler: Scheduler) -> None: Inferer.__init__(self) self.scheduler = scheduler - def __call__( + def __call__( # type: ignore[override] self, inputs: torch.Tensor, diffusion_model: DiffusionModelUNet, @@ -1337,7 +1346,7 @@ def __call__( down_block_res_samples, mid_block_res_sample = controlnet( x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond ) - if mode == "concat": + if mode == "concat" and condition is not None: noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None @@ -1345,7 +1354,7 @@ def __call__( if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) - prediction = diffuse( + prediction: torch.Tensor = diffuse( x=noisy_image, timesteps=timesteps, context=condition, @@ -1356,7 +1365,7 @@ def __call__( return prediction @torch.no_grad() - def sample( + def sample( # type: ignore[override] self, input_noise: torch.Tensor, diffusion_model: DiffusionModelUNet, @@ -1405,7 +1414,7 @@ def sample( if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) - if mode == "concat": + if mode == "concat" and conditioning is not None: model_input = torch.cat([image, conditioning], dim=1) model_output = diffuse( model_input, @@ -1433,7 +1442,7 @@ def sample( return image @torch.no_grad() - def get_likelihood( + def get_likelihood( # type: ignore[override] self, inputs: torch.Tensor, diffusion_model: DiffusionModelUNet, @@ -1443,8 +1452,8 @@ def get_likelihood( save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), verbose: bool = True, seg: torch.Tensor | None = None, ) -> torch.Tensor 
| tuple[torch.Tensor, list[torch.Tensor]]: @@ -1493,7 +1502,7 @@ def get_likelihood( if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) - if mode == "concat": + if mode == "concat" and conditioning is not None: noisy_image = torch.cat([noisy_image, conditioning], dim=1) model_output = diffuse( noisy_image, @@ -1568,7 +1577,7 @@ def get_likelihood( + torch.exp(log_posterior_variance - log_predicted_variance) + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) - total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) if save_intermediates: intermediates.append(kl.cpu()) @@ -1606,11 +1615,11 @@ def __init__( raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") self.ldm_latent_shape = ldm_latent_shape self.autoencoder_latent_shape = autoencoder_latent_shape - if self.ldm_latent_shape is not None: + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) - def __call__( + def __call__( # type: ignore[override] self, inputs: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, @@ -1665,14 +1674,14 @@ def __call__( return prediction @torch.no_grad() - def sample( + def sample( # type: ignore[override] self, input_noise: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, diffusion_model: DiffusionModelUNet, controlnet: ControlNet, cn_cond: torch.Tensor, - scheduler: Callable[..., torch.Tensor] | None = None, + scheduler: Scheduler | None = None, save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, @@ -1757,7 +1766,7 @@ def sample( return image @torch.no_grad() - def get_likelihood( + def get_likelihood( # type: ignore[override] self, inputs: torch.Tensor, autoencoder_model: AutoencoderKL | VQVAE, @@ -1848,7 +1857,7 @@ def __call__( inputs: torch.Tensor, vqvae_model: VQVAE, transformer_model: DecoderOnlyTransformer, - ordering: Callable[..., torch.Tensor], + ordering: Ordering, condition: torch.Tensor | None = None, return_latent: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: @@ -1883,10 +1892,10 @@ def __call__( seq_len = latent.shape[1] max_seq_len = transformer_model.max_seq_len if max_seq_len < seq_len: - start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item() + start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()) else: start = 0 - prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) if return_latent: return prediction, target[:, start : start + max_seq_len], latent_spatial_dim else: @@ -1895,11 +1904,11 @@ def __call__( @torch.no_grad() def sample( self, - latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], + latent_spatial_dim: tuple[int, int, int] | tuple[int, int], starting_tokens: torch.Tensor, vqvae_model: VQVAE, transformer_model: DecoderOnlyTransformer, - ordering: Callable[..., torch.Tensor], + ordering: Ordering, conditioning: torch.Tensor | None = None, temperature: float = 1.0, top_k: int | None = None, @@ -1962,7 +1971,7 @@ def get_likelihood( inputs: torch.Tensor, vqvae_model: VQVAE, transformer_model: 
DecoderOnlyTransformer, - ordering: Callable[..., torch.Tensor], + ordering: Ordering, condition: torch.Tensor | None = None, resample_latent_likelihoods: bool = False, resample_interpolation_mode: str = "nearest", diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -126,6 +126,7 @@ version_leq, ) from .nvtx import Range +from .ordering import Ordering from .profiling import ( PerfContext, ProfileHandler, From 0ded5a3741fb5b80c348bdaea1ed6f22964759c2 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 14:00:39 +0000 Subject: [PATCH 06/37] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: 495758f01b8b95a25ea969c12e91c7818e6d7640 I, Mark Graham , hereby add my Signed-off-by to this commit: 3ebaf9f5c01f105f923903aa18e8b7b357637829 Signed-off-by: Mark Graham --- tests/test_diffusion_inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 03a7badc20..aa5fd6633f 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -23,7 +23,7 @@ TEST_CASES = [ [ { - "spatial_dims": 2, + "spatial_dimss": 2, "in_channels": 1, "out_channels": 1, "channels": [8], From 25f06d64bc99ec10313bddbb831aed52483cd840 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 14:00:55 +0000 Subject: [PATCH 07/37] DCO Signed-off-by: Mark Graham --- tests/test_diffusion_inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index aa5fd6633f..03a7badc20 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -23,7 +23,7 @@ TEST_CASES = [ [ { - "spatial_dimss": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 1, "channels": [8], From 818eb683db9b0dc44e16542fd2da8265bb6a6547 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 14:21:25 +0000 Subject: [PATCH 08/37] Skip test if scipy not installed Signed-off-by: Mark Graham --- tests/test_diffusion_inferer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 03a7badc20..ecd4855385 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -19,6 +19,9 @@ from monai.inferers import DiffusionInferer from monai.networks.nets import DiffusionModelUNet from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") TEST_CASES = [ [ @@ -150,6 +153,7 @@ def test_get_likelihood(self, model_params, input_shape): self.assertEqual(intermediates[0].shape, input.shape) self.assertEqual(likelihood.shape[0], input.shape[0]) + @unittest.skipUnless(has_scipy, "Requires scipy library.") def test_normal_cdf(self): from scipy.stats import norm From 74d7663796c5d21e870961f42127aab38bcc36df Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 15:41:46 +0000 Subject: [PATCH 09/37] Skip test if scipy not installed Signed-off-by: Mark Graham --- tests/test_controlnet_inferers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 8df61b6cde..1f675537dc 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -26,6 +26,9 @@ SPADEDiffusionModelUNet, ) from 
monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") CNDM_TEST_CASES = [ [ @@ -589,6 +592,7 @@ def test_get_likelihood(self, model_params, controlnet_params, input_shape): self.assertEqual(intermediates[0].shape, input.shape) self.assertEqual(likelihood.shape[0], input.shape[0]) + @unittest.skipUnless(has_scipy, "Requires scipy library.") def test_normal_cdf(self): from scipy.stats import norm From 6b0b389a962d647158205b4a9d7ee006020dca44 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 15:56:31 +0000 Subject: [PATCH 10/37] Try to correct non-contiguous error Signed-off-by: Mark Graham --- monai/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 1532215c70..b3e87a9ed1 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x + h = x.contiguous() h = self.norm1(h) h = self.nonlinearity(h) From a1e1bdad9f5b0e08fc48eb6f565ddc2a8a87142f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 9 Jan 2024 16:47:29 +0000 Subject: [PATCH 11/37] Contigous again Signed-off-by: Mark Graham --- monai/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index b3e87a9ed1..45cbd043de 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: batch, channel, height, width, depth = x.shape # norm - x = self.norm(x) + x = self.norm(x.contiguous()) if self.spatial_dims == 2: x = x.view(batch, channel, height * width).transpose(1, 2) From a15da8330a9ac2758633a0b477dbe07619778e49 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 11:57:48 +0000 Subject: [PATCH 12/37] Adds missing VQVAETranformerInferer tests Signed-off-by: Mark Graham --- monai/inferers/__init__.py | 1 + tests/test_vqvaetransformer_inferer.py | 284 +++++++++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 tests/test_vqvaetransformer_inferer.py diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 8141b3111e..599848e095 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -23,6 +23,7 @@ SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, +VQVAETransformerInferer ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 0000000000..1a511d287b --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import VQVAETransformerInferer +from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils.ordering import Ordering, OrderingType + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 4, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, + (2, 1, 8, 8), + (2, 4, 17), + (2, 2, 2), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 8, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, + (2, 1, 8, 8, 8), + (2, 8, 17), + (2, 2, 2, 2), + ], +] + + +class TestVQVAETransformerInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + self.assertEqual(prediction.shape, logits_shape) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) + self.assertEqual(prediction.shape, cropped_logits_shape) + + def test_sample(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + 
embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + def test_sample_shorter_sequence(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=2, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_resampling( + self, stage_1_params, stage_2_params, ordering_params, 
input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + resample_latent_likelihoods=True, + resample_interpolation_mode="nearest", + ) + self.assertEqual(likelihood.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() From 9cb196dd265dbcf2cdb4a0dcde9f31aa2d205525 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 11:58:49 +0000 Subject: [PATCH 13/37] Formatting Signed-off-by: Mark Graham --- monai/inferers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 599848e095..fc78b9f7c4 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -23,7 +23,7 @@ SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, -VQVAETransformerInferer + VQVAETransformerInferer, ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter From ecc1d7c68c89556c8e9671d4e8001550c919359d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 12:11:03 +0000 Subject: [PATCH 14/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 897e252cd0..6a7b47c98e 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -41,7 +41,6 @@ ) from monai.networks.schedulers import Scheduler -# from monai.networks.schedulers import Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp From f0f53e58f21a6ae7bedf7784add1048e30fe353b Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 13:52:45 +0000 Subject: [PATCH 15/37] Remove unnecessary partial calls, increase test coverage Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 7 ++---- tests/test_latent_diffusion_inferer.py | 33 +++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 6a7b47c98e..13a1161609 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -40,7 +40,6 @@ SPADEDiffusionModelUNet, ) from monai.networks.schedulers import Scheduler - from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp @@ -1134,8 +1133,6 @@ def __call__( # type: ignore[override] latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) call = super().__call__ - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - call = partial(super().__call__, seg=seg) prediction: torch.Tensor = call( inputs=latent, @@ -1144,6 +1141,7 @@ def __call__( # type: ignore[override] timesteps=timesteps, condition=condition, 
mode=mode, + seg=seg, ) return prediction @@ -1187,8 +1185,6 @@ def sample( # type: ignore[override] ) sample = super().sample - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - sample = partial(super().sample, seg=seg) outputs = sample( input_noise=input_noise, @@ -1199,6 +1195,7 @@ def sample( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, + seg=seg, ) if save_intermediates: diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 402af17c7a..4ab803bb6f 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -105,6 +105,35 @@ (1, 1, 16, 16, 16), (1, 3, 4, 4, 4), ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], ] TEST_CASES_DIFF_SHAPES = [ [ @@ -407,12 +436,14 @@ def test_sample_intermediates( else: input_shape_seg[1] = autoencoder_params["label_nc"] input_seg = torch.randn(input_shape_seg).to(device) - sample = inferer.sample( + sample, intermediates = inferer.sample( input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, seg=input_seg, + save_intermediates=True, + intermediate_steps=1, ) else: sample, intermediates = inferer.sample( From 5c018cf4febf2e8627cd65c2d3b9e0396d61048a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 14:27:43 +0000 Subject: [PATCH 16/37] Test if changing inferer inheritance affects contiguous error Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 13a1161609..b61e86031d 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -771,7 +771,7 @@ def network_wrapper( return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) -class DiffusionInferer(Inferer): +class DiffusionInferer(nn.Module): """ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. From 86b21e88561c783a5a3e73262c36232f0e59c7df Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 14:43:32 +0000 Subject: [PATCH 17/37] contig Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 2 +- monai/networks/nets/diffusion_model_unet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index b61e86031d..13a1161609 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -771,7 +771,7 @@ def network_wrapper( return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) -class DiffusionInferer(nn.Module): +class DiffusionInferer(Inferer): """ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. 
diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 45cbd043de..0441cc9cfe 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1957,7 +1957,7 @@ def forward( h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) # 7. output block - output: torch.Tensor = self.out(h) + output: torch.Tensor = self.out(h.contiguous()) return output From d654216fab5c71ad4045d11aead2ab5487de653b Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 14:56:51 +0000 Subject: [PATCH 18/37] contig Signed-off-by: Mark Graham --- monai/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 0441cc9cfe..befc7a6c4a 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x.contiguous() + h = x#.contiguous() h = self.norm1(h) h = self.nonlinearity(h) From ccc31100e0eb23bc1e03ef6f60db28512e4a9a13 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 10 Jan 2024 15:11:07 +0000 Subject: [PATCH 19/37] undo Signed-off-by: Mark Graham --- monai/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index befc7a6c4a..0441cc9cfe 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x#.contiguous() + h = x.contiguous() h = self.norm1(h) h = self.nonlinearity(h) From 22ba322a61e5ee81aed92a691a2da131849fa35a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:28:49 +0000 Subject: [PATCH 20/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 13a1161609..ece5715116 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1132,9 +1132,7 @@ def __call__( # type: ignore[override] if self.ldm_latent_shape is not None: latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) - call = super().__call__ - - prediction: torch.Tensor = call( + prediction: torch.Tensor = super().__call__( inputs=latent, diffusion_model=diffusion_model, noise=noise, From 15af70662ea87c18f5cb6d650894f4aba24a1445 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:29:01 +0000 Subject: [PATCH 21/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ece5715116..19c1d8410d 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1182,9 +1182,7 @@ def sample( # type: ignore[override] "labels for each must be compatible. 
" ) - sample = super().sample - - outputs = sample( + outputs = super().sample( input_noise=input_noise, diffusion_model=diffusion_model, scheduler=scheduler, From 2f6bda5e9a9e9444873fae791f3fbe595fbebe84 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:29:19 +0000 Subject: [PATCH 22/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 19c1d8410d..b61b3ef001 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1269,11 +1269,7 @@ def get_likelihood( # type: ignore[override] if self.ldm_latent_shape is not None: latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) - get_likelihood = super().get_likelihood - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - get_likelihood = partial(super().get_likelihood, seg=seg) - - outputs = get_likelihood( + outputs = super().get_likelihood( inputs=latents, diffusion_model=diffusion_model, scheduler=scheduler, From 22bf240b216491ff38228085aca4f8546eaa8c57 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:29:45 +0000 Subject: [PATCH 23/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index b61b3ef001..66085af430 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1277,6 +1277,7 @@ def get_likelihood( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, + seg=seg ) if save_intermediates and resample_latent_likelihoods: From 095be619e401d05ffef888f8a4308bd74dd84f7d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:30:57 +0000 Subject: [PATCH 24/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 66085af430..644857eeb6 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1645,11 +1645,7 @@ def __call__( # type: ignore[override] if cn_cond.shape[2:] != latent.shape[2:]: cn_cond = F.interpolate(cn_cond, latent.shape[2:]) - call = super().__call__ - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - call = partial(super().__call__, seg=seg) - - prediction = call( + prediction = super().__call__( inputs=latent, diffusion_model=diffusion_model, controlnet=controlnet, From cc28b20a0b0d29ccc06a281e6c5e8849b44225a6 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:31:55 +0000 Subject: [PATCH 25/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 644857eeb6..1c3a21cdff 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1704,11 +1704,7 @@ def sample( # type: ignore[override] if cn_cond.shape[2:] != input_noise.shape[2:]: cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) - sample = super().sample - if 
isinstance(diffusion_model, SPADEDiffusionModelUNet): - sample = partial(super().sample, seg=seg) - - outputs = sample( + outputs = super().sample( input_noise=input_noise, diffusion_model=diffusion_model, controlnet=controlnet, From 3ba1363e60c6fb39ce53fe87ac493dd6ea2d8091 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:32:01 +0000 Subject: [PATCH 26/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 1c3a21cdff..2d7ca1d826 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1801,11 +1801,7 @@ def get_likelihood( # type: ignore[override] if self.ldm_latent_shape is not None: latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) - get_likelihood = super().get_likelihood - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - get_likelihood = partial(super().get_likelihood, seg=seg) - - outputs = get_likelihood( + outputs = super().get_likelihood( inputs=latents, diffusion_model=diffusion_model, controlnet=controlnet, From f162deecf02ad4b6537dae04be56a72ff3e66dcb Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:32:12 +0000 Subject: [PATCH 27/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 2d7ca1d826..51e3491b87 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1811,6 +1811,7 @@ def get_likelihood( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, + seg=seg ) if save_intermediates and resample_latent_likelihoods: From 4c2085c1f9b144fc99178cecfe810f242c1f25e4 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:32:25 +0000 Subject: [PATCH 28/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 51e3491b87..45b887e0c2 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1715,6 +1715,7 @@ def sample( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, + seg=seg ) if save_intermediates: From 4c6d788479e0b36cd3e883d3e6da85168f5ad138 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:32:42 +0000 Subject: [PATCH 29/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 45b887e0c2..27ee053fe6 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1654,6 +1654,7 @@ def __call__( # type: ignore[override] cn_cond=cn_cond, condition=condition, mode=mode, + seg=seg ) return prediction From 47b5958f869b08406c10a718c43975d7fd9d3093 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:32:58 +0000 Subject: [PATCH 30/37] Update monai/inferers/inferer.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- 
monai/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 27ee053fe6..0e5d9fa8b2 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1699,7 +1699,7 @@ def sample( # type: ignore[override] ): raise ValueError( "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" - "labels for each must be compatible. " + "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}" ) if cn_cond.shape[2:] != input_noise.shape[2:]: From 97bd662ba81fe68ad867a5192894ac43b987f3bc Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 11 Jan 2024 14:42:45 +0000 Subject: [PATCH 31/37] Updates to comments Signed-off-by: Mark Graham --- monai/inferers/inferer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0e5d9fa8b2..72bcb8fd5a 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1049,7 +1049,8 @@ def _get_decoder_log_likelihood( original_input_range: the [min,max] intensity range of the input data before any scaling was applied. scaled_input_range: the [min,max] intensity range of the input data after scaling. """ - assert inputs.shape == means.shape + if inputs.shape != means.shape: + raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}") bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( original_input_range[1] - original_input_range[0] ) @@ -1067,7 +1068,6 @@ def _get_decoder_log_likelihood( log_cdf_plus, torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), ) - assert log_probs.shape == inputs.shape return log_probs @@ -1178,8 +1178,9 @@ def sample( # type: ignore[override] and autoencoder_model.decoder.label_nc != diffusion_model.label_nc ): raise ValueError( - "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" - "labels for each must be compatible. 
" + f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and" + f"{diffusion_model.label_nc}" ) outputs = super().sample( @@ -1277,7 +1278,7 @@ def get_likelihood( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, - seg=seg + seg=seg, ) if save_intermediates and resample_latent_likelihoods: @@ -1654,7 +1655,7 @@ def __call__( # type: ignore[override] cn_cond=cn_cond, condition=condition, mode=mode, - seg=seg + seg=seg, ) return prediction @@ -1716,7 +1717,7 @@ def sample( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, - seg=seg + seg=seg, ) if save_intermediates: @@ -1813,7 +1814,7 @@ def get_likelihood( # type: ignore[override] conditioning=conditioning, mode=mode, verbose=verbose, - seg=seg + seg=seg, ) if save_intermediates and resample_latent_likelihoods: From 553c94b326f714c0d745c46cd3b77eb630346d1e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 15 Jan 2024 16:00:24 +0000 Subject: [PATCH 32/37] Move tests --- test_spade_autoencoderkl.py => tests/test_spade_autoencoderkl.py | 0 .../test_spade_diffusion_model_unet.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename test_spade_autoencoderkl.py => tests/test_spade_autoencoderkl.py (100%) rename test_spade_diffusion_model_unet.py => tests/test_spade_diffusion_model_unet.py (100%) diff --git a/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py similarity index 100% rename from test_spade_autoencoderkl.py rename to tests/test_spade_autoencoderkl.py diff --git a/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py similarity index 100% rename from test_spade_diffusion_model_unet.py rename to tests/test_spade_diffusion_model_unet.py From 38f832a88ef1bc8fba10f57a4c9f43d9392c2e45 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 15 Jan 2024 16:05:00 +0000 Subject: [PATCH 33/37] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: 553c94b326f714c0d745c46cd3b77eb630346d1e Signed-off-by: Mark Graham --- tests/test_spade_diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py index 113e58ed89..4e4e3be59e 100644 --- a/tests/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -25,7 +25,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_res_blocks": 1, + "num_res_fblocks": 1, "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, From 7ef3fb5de993af666b1503a908cf5eb8be472a22 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 15 Jan 2024 16:05:20 +0000 Subject: [PATCH 34/37] DCO Signed-off-by: Mark Graham --- tests/test_spade_diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py index 4e4e3be59e..113e58ed89 100644 --- a/tests/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -25,7 +25,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_res_fblocks": 1, + "num_res_blocks": 1, "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, From ac891d8d4de07d2a3e6e94eaa352b781d5def4d5 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 17 Jan 2024 09:16:46 +0000 Subject: 
[PATCH 35/37] Updates setup.cof to fix premerge Signed-off-by: Mark Graham --- setup.cfg | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/setup.cfg b/setup.cfg index 123da68dfa..0069214de3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ all = scipy>=1.7.1 pillow tensorboard - gdown>=4.4.0 + gdown==4.6.3 pytorch-ignite==0.4.11 torchvision itk>=5.2 @@ -60,12 +60,12 @@ all = lmdb psutil cucim>=23.2.0 - openslide-python==1.1.2 + openslide-python tifffile imagecodecs pandas einops - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib @@ -97,7 +97,7 @@ pillow = tensorboard = tensorboard gdown = - gdown>=4.4.0 + gdown==4.6.3 ignite = pytorch-ignite==0.4.11 torchvision = @@ -113,7 +113,7 @@ psutil = cucim = cucim>=23.2.0 openslide = - openslide-python==1.1.2 + openslide-python tifffile = tifffile imagecodecs = @@ -123,7 +123,7 @@ pandas = einops = einops transformers = - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow = mlflow matplotlib = @@ -173,6 +173,7 @@ max_line_length = 120 # B028 https://github.com/Project-MONAI/MONAI/issues/5855 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 +# B036 https://github.com/Project-MONAI/MONAI/issues/7396 ignore = E203 E501 @@ -186,6 +187,7 @@ ignore = B028 B907 B908 + B036 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py From 9073d85f1933b76caab719ee3f136ed342290441 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 17 Jan 2024 09:57:51 +0000 Subject: [PATCH 36/37] Fixes to tests for premerge Signed-off-by: Mark Graham --- tests/test_flexible_unet.py | 2 +- tests/test_invertd.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 1218ce6e85..1d831f0976 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -39,7 +39,7 @@ class DummyEncoder(BaseEncoder): def get_encoder_parameters(cls): basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False} param_dict_list = [basic_dict] - for key in basic_dict: + for key in basic_dict.keys(): cur_dict = basic_dict.copy() del cur_dict[key] param_dict_list.append(cur_dict) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index cd2e91257a..2e6ee35981 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -112,15 +112,15 @@ def test_invert(self): self.assertTupleEqual(i.shape[1:], (101, 100, 107)) # check the case that different items use different interpolation mode to invert transforms - d = item["image_inverted1"] + j = item["image_inverted1"] # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(j.shape, (1, 101, 100, 107)) - d = item["label_inverted1"] + k = item["label_inverted1"] # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertGreater(torch.sum(k.to(torch.float) - 
k.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(k.shape, (1, 101, 100, 107)) # check labels match reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) From 1d8e7ccedd178f557e84b432bb26af10ce0e1ec3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 18 Jan 2024 09:05:24 +0000 Subject: [PATCH 37/37] Remove random test Signed-off-by: Mark Graham --- tests/test_ordering.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 0c52dba5e5..e6b235e179 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -182,24 +182,6 @@ ], ] -TEST_2D_RANDOM = [ - [ - { - "ordering_type": OrderingType.RANDOM, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [[0, 1, 2, 3], [0, 1, 3, 2]], - ] -] TEST_3D = [ [ @@ -291,17 +273,6 @@ def test_ordering_transformation_failure(self, input_param): with self.assertRaises(ValueError): Ordering(**input_param) - @parameterized.expand(TEST_2D_RANDOM) - def test_random(self, input_param, not_in_expected_sequence_ordering): - ordering = Ordering(**input_param) - - not_in = [ - np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) - for sequence in not_in_expected_sequence_ordering - ] - - self.assertFalse(np.any(not_in)) - @parameterized.expand(TEST_REVERT) def test_revert(self, input_param): sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten()
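For orientation, a minimal usage sketch of the DiffusionInferer added in this series, showing the training-style forward pass and sampling calls exercised by the tests above. The scheduler and UNet constructors below (DDPMScheduler, DiffusionModelUNet, and their arguments) are assumptions sized to mirror the small test configurations, not code taken from the diffs.

import torch

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler  # assumed scheduler class

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# small 2D denoising UNet, sized like the toy networks used in the tests above (assumed args)
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(8, 8),
    attention_levels=(False, False),
    num_res_blocks=1,
    norm_num_groups=8,
).to(device)

scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

# training-style call: the inferer noises the input at random timesteps and the model predicts the noise
images = torch.randn(2, 1, 16, 16, device=device)
noise = torch.randn_like(images)
timesteps = torch.randint(0, 1000, (images.shape[0],), device=device).long()
prediction = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

# sampling: start from pure noise and run the scheduler's reverse process step by step
sample = inferer.sample(
    input_noise=torch.randn(2, 1, 16, 16, device=device),
    diffusion_model=model,
    scheduler=scheduler,
    verbose=False,
)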