From 6ca3b1bbb88f3ac7193eebe9df946278a7c37ea7 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 21 Jun 2024 21:10:29 +0000 Subject: [PATCH 01/36] init Signed-off-by: dongyang0122 --- monai/apps/generation/__init__.py | 10 + monai/apps/generation/maisi/__init__.py | 31 + .../generation/maisi/inferers/__init__.py | 10 + .../maisi/inferers/inferer_maisi.py | 788 ++++++ .../generation/maisi/networks/__init__.py | 10 + .../networks/diffusion_model_unet_maisi.py | 2154 +++++++++++++++++ 6 files changed, 3003 insertions(+) create mode 100644 monai/apps/generation/__init__.py create mode 100644 monai/apps/generation/maisi/__init__.py create mode 100644 monai/apps/generation/maisi/inferers/__init__.py create mode 100644 monai/apps/generation/maisi/inferers/inferer_maisi.py create mode 100644 monai/apps/generation/maisi/networks/__init__.py create mode 100644 monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py diff --git a/monai/apps/generation/__init__.py b/monai/apps/generation/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py new file mode 100644 index 0000000000..ef42d42730 --- /dev/null +++ b/monai/apps/generation/maisi/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import subprocess +import sys + + +def install_and_import(package, package_fullname=None): + if package_fullname is None: + package_fullname = package + + try: + __import__(package) + except ImportError: + print(f"'{package}' is not installed. Installing now...") + subprocess.check_call([sys.executable, "-m", "pip", "install", package_fullname]) + print(f"'{package}' installation completed.") + __import__(package) + + +install_and_import("generative", "monai-generative") diff --git a/monai/apps/generation/maisi/inferers/__init__.py b/monai/apps/generation/maisi/inferers/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/inferers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py new file mode 100644 index 0000000000..1110b42c9e --- /dev/null +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -0,0 +1,788 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Callable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.inferers import Inferer +from monai.utils import optional_import + +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + + +IF_PROFILE = False + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: nn.Module) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + condition: Conditioning for network input. + mode: Conditioning mode for the network. 
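+            top_region_index_tensor: additional conditioning tensor passed through to the diffusion model
+                (in MAISI this is assumed to encode the top-most body region covered by the volume).
+            bottom_region_index_tensor: additional conditioning tensor passed through to the diffusion model
+                (in MAISI this is assumed to encode the bottom-most body region covered by the volume).
+            spacing_tensor: additional conditioning tensor passed through to the diffusion model
+                (in MAISI this is assumed to encode the voxel spacing of the volume).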
+ """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise( + original_samples=inputs, noise=noise, timesteps=timesteps + ) + if mode == "concat": + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + prediction = diffusion_model( + x=noisy_image, + timesteps=timesteps, + context=condition, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + + if IF_PROFILE: + torch.cuda.cudart().cudaProfilerStart() + + for t in progress_bar: + if IF_PROFILE: + torch.cuda.nvtx.range_push("forward") + + # 1. predict noise model_output + if mode == "concat": + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + else: + model_output = diffusion_model( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + + if IF_PROFILE: + torch.cuda.nvtx.range_pop() + + # diff = torch.norm(model_output).cpu().item() + # print(diff) + # with open("diff.txt", "a") as file: + # file.write(f"{diff}\n") + + # 2. 
compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + + if IF_PROFILE: + torch.cuda.cudart().cudaProfilerStop() + + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise( + original_samples=inputs, noise=noise, timesteps=timesteps + ) + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model( + noisy_image, timesteps=timesteps, context=None + ) + else: + model_output = diffusion_model( + x=noisy_image, timesteps=timesteps, context=conditioning + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[ + 1 + ] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split( + model_output, inputs.shape[1], dim=1 + ) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. 
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = ( + noisy_image - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - ( + beta_prod_t**0.5 + ) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = ( + alpha_prod_t_prev ** (0.5) * scheduler.betas[t] + ) / beta_prod_t + current_sample_coeff = ( + scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + ) + + # 5. Compute predicted previous sample ยต_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = ( + pred_original_sample_coeff * pred_original_sample + + current_sample_coeff * noisy_image + ) + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean( + timestep=t, x_0=inputs, x_t=noisy_image + ) + posterior_variance = scheduler._get_variance( + timestep=t, predicted_variance=predicted_variance + ) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = ( + torch.log(predicted_variance) + if predicted_variance + else log_posterior_variance + ) + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) + * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(axis=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + + torch.tanh( + torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) + * (x + 0.044715 * torch.pow(x, 3)) + ) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. 
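+
+        Returns:
+            the per-element log probabilities of the discretized Gaussian, with the same shape as ``inputs``.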
+ """ + assert inputs.shape == means.shape + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where( + inputs > 0.999, + log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)), + ), + ) + assert log_probs.shape == inputs.shape + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + """ + + def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + + def __call__( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. 
+ autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + """ + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + intermediates.append( + autoencoder_model.decode_stage_2_outputs( + latent_intermediate / self.scale_factor + ) + ) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. 
+ resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ( + "nearest", + "bilinear", + "trilinear", + ): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + ) + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample( + size=inputs.shape[2:], mode=resample_interpolation_mode + ) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(Inferer): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. 
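+        # Illustrative note (hypothetical values): with num_embeddings=256 the BOS token id is 256, so a
+        # flattened latent sequence [12, 7, 255, ...] becomes [256, 12, 7, 255, ...] after the padding below.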
+ latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = torch.randint( + low=0, high=seq_len + 1 - max_seq_len, size=(1,) + ).item() + else: + start = 0 + prediction = transformer_model( + x=latent[:, start : start + max_seq_len], context=condition + ) + if return_latent: + return ( + prediction, + target[:, start : start + max_seq_len], + latent_spatial_dim, + ) + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], + starting_tokens: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ 
+ Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ( + "nearest", + "bilinear", + "trilinear", + ): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model( + x=latent[:, : transformer_model.max_seq_len], context=condition + ) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather( + probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2) + ).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample( + size=inputs.shape[2:], mode=resample_interpolation_mode + ) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/apps/generation/maisi/networks/__init__.py b/monai/apps/generation/maisi/networks/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/networks/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py new file mode 100644 index 0000000000..03d685a461 --- /dev/null +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -0,0 +1,2154 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep +from torch import nn + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["CustomDiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class CrossAttention(nn.Module): + """ + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. 
+ upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape( + batch_size * self.num_heads, seq_len, dim // self.num_heads + ) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape( + batch_size // self.num_heads, seq_len, dim * self.num_heads + ) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward( + self, x: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + return self.to_out(x) + + +class BasicTransformerBlock(nn.Module): + """ + A basic Transformer block. 
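+    Self-attention, cross-attention and a feed-forward layer are applied in sequence, each preceded by layer
+    normalization and wrapped in a residual connection.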
+ + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock( + hidden_size=num_channels, + mlp_dim=num_channels * 4, + act="GEGLU", + dropout_rate=dropout, + ) + self.attn2 = CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward( + self, x: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
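+
+    Example (illustrative sketch added for documentation only; the channel, spatial and context sizes are
+    arbitrary assumptions, not values from the original code)::
+
+        st = SpatialTransformer(
+            spatial_dims=2, in_channels=32, num_attention_heads=4, num_head_channels=8, cross_attention_dim=64
+        )
+        x = torch.rand(1, 32, 16, 16)
+        context = torch.rand(1, 10, 64)  # (batch, sequence length, cross_attention_dim)
+        y = st(x, context=context)  # output has the same shape as x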
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=norm_eps, + affine=True, + ) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward( + self, x: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape( + batch, height * width * depth, inner_dim + ) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = ( + x.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + if self.spatial_dims == 3: + x = ( + x.reshape(batch, height, width, depth, inner_dim) + .permute(0, 4, 1, 2, 3) + .contiguous() + ) + + x = self.proj_out(x) + return x + residual + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = ( + num_channels // num_head_channels if num_head_channels is not None else 1 + ) + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=num_channels, + eps=norm_eps, + affine=True, + ) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape( + batch_size * self.num_heads, seq_len, dim // self.num_heads + ) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape( + batch_size // self.num_heads, seq_len, dim * self.num_heads + ) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding( + timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000 +) -> torch.Tensor: + """ + Create sinusoidal timestep 
embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') + + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class Downsample(nn.Module): + """ + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + use_conv: bool, + out_channels: int | None = None, + padding: int = 1, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError( + "num_channels and out_channels must be equal when use_conv=False" + ) + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + return self.op(x) + + +class Upsample(nn.Module): + """ + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. 
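+
+    Example (illustrative sketch added for documentation only; the shapes are arbitrary assumptions)::
+
+        up = Upsample(spatial_dims=2, num_channels=8, use_conv=True)
+        y = up(torch.rand(1, 8, 16, 16))  # spatial size is doubled: (1, 8, 32, 32)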
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + use_conv: bool, + out_channels: int | None = None, + padding: int = 1, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + self.conv = None + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class ResnetBlock(nn.Module): + """ + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=norm_eps, + affine=True, + ) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=self.out_channels, + eps=norm_eps, + affine=True, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = 
h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + + return self.skip_connection(x) + h + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. 
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
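+
+    Example:
+        A minimal usage sketch (not taken from this module's tests); the channel count,
+        embedding size and spatial size are illustrative assumptions chosen only to satisfy
+        the default ``norm_num_groups=32`` divisibility constraint::
+
+            # illustrative sizes, not values prescribed by MAISI
+            block = AttnMidBlock(spatial_dims=2, in_channels=64, temb_channels=128)
+            hidden_states = torch.randn(1, 64, 16, 16)
+            temb = torch.randn(1, 128)
+            out = block(hidden_states, temb)  # output keeps the input shape (1, 64, 16, 16)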
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
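+
+    Example:
+        A minimal usage sketch; the channel counts, context length and ``cross_attention_dim``
+        below are illustrative assumptions, not values used elsewhere in this file::
+
+            # illustrative sizes, not values prescribed by MAISI
+            block = CrossAttnMidBlock(
+                spatial_dims=2,
+                in_channels=64,
+                temb_channels=128,
+                num_head_channels=8,
+                cross_attention_dim=32,
+            )
+            hidden_states = torch.randn(1, 64, 16, 16)
+            temb = torch.randn(1, 128)
+            context = torch.randn(1, 4, 32)  # (batch, sequence length, cross_attention_dim)
+            out = block(hidden_states, temb, context=context)  # shape stays (1, 64, 16, 16)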
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. 
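+
+    Example:
+        A minimal usage sketch showing how the skip connection from the corresponding down
+        block is consumed; all sizes are illustrative assumptions::
+
+            # illustrative sizes, not values prescribed by MAISI
+            block = UpBlock(
+                spatial_dims=2,
+                in_channels=64,
+                prev_output_channel=64,
+                out_channels=64,
+                temb_channels=128,
+            )
+            hidden_states = torch.randn(1, 64, 16, 16)
+            skips = [torch.randn(1, 64, 16, 16)]  # residuals are popped from the end of the list
+            temb = torch.randn(1, 128)
+            # concatenated with the skip, passed through the resnet, then upsampled to (1, 64, 32, 32)
+            out = block(hidden_states, skips, temb)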
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = ( + in_channels if (i == num_res_blocks - 1) else out_channels + ) + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
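+
+    Example:
+        A minimal 3D usage sketch; channel counts and spatial sizes are illustrative
+        assumptions chosen to satisfy the default ``norm_num_groups=32``::
+
+            # illustrative sizes, not values prescribed by MAISI
+            block = AttnUpBlock(
+                spatial_dims=3,
+                in_channels=64,
+                prev_output_channel=64,
+                out_channels=64,
+                temb_channels=128,
+                num_head_channels=8,
+            )
+            hidden_states = torch.randn(1, 64, 8, 8, 8)
+            skips = [torch.randn(1, 64, 8, 8, 8)]
+            temb = torch.randn(1, 128)
+            # self-attention follows each resnet; the result is upsampled to (1, 64, 16, 16, 16)
+            out = block(hidden_states, skips, temb)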
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = ( + in_channels if (i == num_res_blocks - 1) else out_channels + ) + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = ( + in_channels if (i == num_res_blocks - 1) else out_channels + ) + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + 
resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + 
use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class CustomDiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + input_top_region_index: bool = False, + input_bottom_region_index: bool = False, + input_spacing: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
+ ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError( + "CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups" + ) + + if len(num_channels) != len(attention_levels): + raise ValueError( + "CustomDiffusionModelUNet expects num_channels being same size of attention_levels" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep( + num_head_channels, len(attention_levels) + ) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError( + "use_flash_attention is True but xformers is not installed." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + self.input_top_region_index = input_top_region_index + self.input_bottom_region_index = input_bottom_region_index + self.input_spacing = input_spacing + + new_time_embed_dim = time_embed_dim + if self.input_top_region_index: + # self.top_region_index_layer = nn.Linear(4, time_embed_dim) + self.top_region_index_layer = nn.Sequential( + nn.Linear(4, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + new_time_embed_dim += time_embed_dim + if self.input_bottom_region_index: + # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) + self.bottom_region_index_layer = nn.Sequential( + nn.Linear(4, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + new_time_embed_dim += time_embed_dim + if self.input_spacing: + # self.spacing_layer = nn.Linear(3, time_embed_dim) + self.spacing_layer = nn.Sequential( + nn.Linear(3, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + new_time_embed_dim += time_embed_dim + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = 
output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=new_time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(num_channels) - 1) + ] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=num_channels[0], + eps=norm_eps, + affine=True, + ), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) 
-> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + # print(f't_emb: {t_emb}; timesteps {timesteps}.') + # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') + # print(f"emb: {torch.norm(emb)}") + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. input + if self.input_top_region_index: + _emb = self.top_region_index_layer(top_region_index_tensor) + # print(f"top_region_index_layer: {torch.norm(_emb)} {_emb.size()}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_bottom_region_index: + _emb = self.bottom_region_index_layer(bottom_region_index_tensor) + # print(f"bottom_region_index_layer: {torch.norm(_emb)}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_spacing: + _emb = self.spacing_layer(spacing_tensor) + # print(f"spacing_layer: {torch.norm(_emb)}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') + + # 3. initial convolution + h = self.conv_in(x) + # print(f"x: {torch.norm(x)}; h: {torch.norm(h)}") + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError( + "model should have with_conditioning = True if context is provided" + ) + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block( + hidden_states=h, temb=emb, context=context + ) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. 
up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + h = upsample_block( + hidden_states=h, + res_hidden_states_list=res_samples, + temb=emb, + context=context, + ) + + # 7. output block + h = self.out(h) + + return h From 2989c504db26dc426c7297eaab9cb33e128e64ee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 21:14:15 +0000 Subject: [PATCH 02/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/inferers/inferer_maisi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 1110b42c9e..5b4aec5950 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -148,10 +148,10 @@ def sample( bottom_region_index_tensor=bottom_region_index_tensor, spacing_tensor=spacing_tensor, ) - + if IF_PROFILE: torch.cuda.nvtx.range_pop() - + # diff = torch.norm(model_output).cpu().item() # print(diff) # with open("diff.txt", "a") as file: @@ -161,7 +161,7 @@ def sample( image, _ = scheduler.step(model_output, t, image) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) - + if IF_PROFILE: torch.cuda.cudart().cudaProfilerStop() From 39f646916ca2adcd1ec009f3ce2c59e8b824fd0f Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 15:29:32 -0600 Subject: [PATCH 03/36] update Signed-off-by: Dong Yang --- .../networks/diffusion_model_unet_maisi.py | 2065 +---------------- 1 file changed, 40 insertions(+), 2025 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 03d685a461..b4b4bc878c 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -31,1880 +31,69 @@ from __future__ import annotations -import importlib.util -import math from collections.abc import Sequence import torch -import torch.nn.functional as F -from monai.networks.blocks import Convolution, MLPBlock -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep from torch import nn -# To install xformers, use pip install xformers==0.0.16rc401 -if importlib.util.find_spec("xformers") is not None: - import xformers - import xformers.ops +__all__ = ["DiffusionModelUNetMaisi"] - has_xformers = True -else: - xformers = None - has_xformers = False +from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet +from generative.networks.nets.diffusion_model_unet import get_timestep_embedding -# TODO: Use MONAI's optional_import -# from monai.utils import optional_import -# xformers, has_xformers = optional_import("xformers.ops", name="xformers") - -__all__ = ["CustomDiffusionModelUNet"] - - -def zero_module(module: nn.Module) -> nn.Module: - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -class CrossAttention(nn.Module): - """ - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. 
- num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = ( - cross_attention_dim if cross_attention_dim is not None else query_dim - ) - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape( - batch_size * self.num_heads, seq_len, dim // self.num_heads - ) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape( - batch_size // self.num_heads, seq_len, dim * self.num_heads - ) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward( - self, x: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x 
= x.to(query.dtype) - - return self.to_out(x) - - -class BasicTransformerBlock(nn.Module): - """ - A basic Transformer block. - - Args: - num_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - dropout: dropout probability to use. - cross_attention_dim: size of the context vector for cross attention. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - num_channels: int, - num_attention_heads: int, - num_head_channels: int, - dropout: float = 0.0, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attn1 = CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention - self.ff = MLPBlock( - hidden_size=num_channels, - mlp_dim=num_channels * 4, - act="GEGLU", - dropout_rate=dropout, - ) - self.attn2 = CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None - self.norm1 = nn.LayerNorm(num_channels) - self.norm2 = nn.LayerNorm(num_channels) - self.norm3 = nn.LayerNorm(num_channels) - - def forward( - self, x: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - # 1. Self-Attention - x = self.attn1(self.norm1(x)) + x - - # 2. Cross-Attention - x = self.attn2(self.norm2(x), context=context) + x - - # 3. Feed-forward - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - num_layers: number of layers of Transformer blocks to use. - dropout: dropout probability to use. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_attention_heads: int, - num_head_channels: int, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - inner_dim = num_attention_heads * num_head_channels - - self.norm = nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=norm_eps, - affine=True, - ) - - self.proj_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=inner_dim, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - num_channels=inner_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - for _ in range(num_layers) - ] - ) - - self.proj_out = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=inner_dim, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - ) - - def forward( - self, x: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - # note: if no context is given, cross-attention defaults to self-attention - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - residual = x - x = self.norm(x) - x = self.proj_in(x) - - inner_dim = x.shape[1] - - if self.spatial_dims == 2: - x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - if self.spatial_dims == 3: - x = x.permute(0, 2, 3, 4, 1).reshape( - batch, height * width * depth, inner_dim - ) - - for block in self.transformer_blocks: - x = block(x, context=context) - - if self.spatial_dims == 2: - x = ( - x.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - if self.spatial_dims == 3: - x = ( - x.reshape(batch, height, width, depth, inner_dim) - .permute(0, 4, 1, 2, 3) - .contiguous() - ) - - x = self.proj_out(x) - return x + residual - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = ( - num_channels // num_head_channels if num_head_channels is not None else 1 - ) - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=num_channels, - eps=norm_eps, - affine=True, - ) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape( - batch_size * self.num_heads, seq_len, dim // self.num_heads - ) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape( - batch_size // self.num_heads, seq_len, dim * self.num_heads - ) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - -def get_timestep_embedding( - timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000 -) -> torch.Tensor: - """ - Create sinusoidal timestep 
embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic - Models" https://arxiv.org/abs/2006.11239. - - Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - embedding_dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - """ - # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') - - if timesteps.ndim != 1: - raise ValueError("Timesteps should be a 1d-array") - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - freqs = torch.exp(exponent / half_dim) - - args = timesteps[:, None].float() * freqs[None, :] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) - - return embedding - - -class Downsample(nn.Module): - """ - Downsampling layer. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is - False, the number of output channels must be the same as the number of input channels. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points - for each dimension. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - use_conv: bool, - out_channels: int | None = None, - padding: int = 1, - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.op = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=2, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - if self.num_channels != self.out_channels: - raise ValueError( - "num_channels and out_channels must be equal when use_conv=False" - ) - self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError( - f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " - f"({self.num_channels})" - ) - return self.op(x) - - -class Upsample(nn.Module): - """ - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. 
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - use_conv: bool, - out_channels: int | None = None, - padding: int = 1, - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - self.conv = None - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - if self.use_conv: - x = self.conv(x) - return x - - -class ResnetBlock(nn.Module): - """ - Residual block with timestep conditioning. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - out_channels: number of output channels. - up: if True, performs upsampling. - down: if True, performs downsampling. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - out_channels: int | None = None, - up: bool = False, - down: bool = False, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = in_channels - self.emb_channels = temb_channels - self.out_channels = out_channels or in_channels - self.up = up - self.down = down - - self.norm1 = nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=norm_eps, - affine=True, - ) - self.nonlinearity = nn.SiLU() - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) - elif down: - self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) - - self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) - - self.norm2 = nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=self.out_channels, - eps=norm_eps, - affine=True, - ) - self.conv2 = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - if self.out_channels == in_channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = self.nonlinearity(h) - - if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = 
h.contiguous() - x = self.upsample(x) - h = self.upsample(h) - elif self.downsample is not None: - x = self.downsample(x) - h = self.downsample(h) - - h = self.conv1(h) - - if self.spatial_dims == 2: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] - else: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] - h = h + temb - - h = self.norm2(h) - h = self.nonlinearity(h) - h = self.conv2(h) - - return self.skip_connection(x) + h - - -class DownBlock(nn.Module): - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. 
- use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class CrossAttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and cross-attention blocks. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = AttentionBlock( - spatial_dims=spatial_dims, - num_channels=in_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class CrossAttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=in_channels, - num_attention_heads=in_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states, context=context) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class UpBlock(nn.Module): - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - - for i in range(num_res_blocks): - res_skip_channels = ( - in_channels if (i == num_res_blocks - 1) else out_channels - ) - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class AttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = ( - in_channels if (i == num_res_blocks - 1) else out_channels - ) - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ +class DiffusionModelUNetMaisi(DiffusionModelUNet): def __init__( self, spatial_dims: int, in_channels: int, - prev_output_channel: int, out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, - add_upsample: bool = True, resblock_updown: bool = False, - num_head_channels: int = 1, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, + input_top_region_index: bool = False, + input_bottom_region_index: bool = False, + input_spacing: bool = False, ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = ( - in_channels if (i == num_res_blocks - 1) else out_channels - ) - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -def get_down_block( - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_downsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: 
int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return DownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - ) - - -def get_mid_block( - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int, - norm_eps: float, - with_conditioning: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_conditioning: - return CrossAttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return AttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - - -def get_up_block( - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_upsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnUpBlock( + super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, - prev_output_channel=prev_output_channel, out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - 
prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, num_res_blocks=num_res_blocks, + num_channels=num_channels, + attention_levels=attention_levels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + with_conditioning=with_conditioning, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + num_class_embeds=num_class_embeds, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) - else: - return UpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - ) - - -class CustomDiffusionModelUNet(nn.Module): - """ - Unet network with timestep embedding and attention mechanisms for conditioning based on - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 - and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - input_top_region_index: bool = False, - input_bottom_region_index: bool = False, - input_spacing: bool = False, - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." 
- ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) - if dropout_cattn > 1.0 or dropout_cattn < 0.0: - raise ValueError("Dropout cannot be negative or >1.0!") - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): - raise ValueError( - "CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups" - ) - - if len(num_channels) != len(attention_levels): - raise ValueError( - "CustomDiffusionModelUNet expects num_channels being same size of attention_levels" - ) - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep( - num_head_channels, len(attention_levels) - ) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - "num_head_channels should have the same length as attention_levels. For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." - ) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) - - if len(num_res_blocks) != len(num_channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`num_channels`." - ) - - if use_flash_attention and not has_xformers: - raise ValueError( - "use_flash_attention is True but xformers is not installed." - ) - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - - self.in_channels = in_channels - self.block_out_channels = num_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=num_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = num_channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - + self.input_top_region_index = input_top_region_index self.input_bottom_region_index = input_bottom_region_index self.input_spacing = input_spacing - + + time_embed_dim = num_channels[0] * 4 new_time_embed_dim = time_embed_dim if self.input_top_region_index: - # self.top_region_index_layer = nn.Linear(4, time_embed_dim) self.top_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), @@ -1912,7 +101,6 @@ def __init__( ) new_time_embed_dim += time_embed_dim if self.input_bottom_region_index: - # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) self.bottom_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), @@ -1920,7 +108,6 @@ def __init__( ) new_time_embed_dim += time_embed_dim if self.input_spacing: - # self.spacing_layer = nn.Linear(3, time_embed_dim) self.spacing_layer = nn.Sequential( nn.Linear(3, time_embed_dim), nn.SiLU(), @@ -1928,111 +115,10 @@ def __init__( ) new_time_embed_dim 
+= time_embed_dim - # down - self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): - input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=new_time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.down_blocks.append(down_block) - - # mid - self.middle_block = get_mid_block( - spatial_dims=spatial_dims, - in_channels=num_channels[-1], - temb_channels=new_time_embed_dim, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - with_conditioning=with_conditioning, - num_head_channels=num_head_channels[-1], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - # up - self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_head_channels = list(reversed(num_head_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(num_channels) - 1) - ] - - is_final_block = i == len(num_channels) - 1 - - up_block = get_up_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - prev_output_channel=prev_output_channel, - out_channels=output_channel, - temb_channels=new_time_embed_dim, - num_res_blocks=reversed_num_res_blocks[i] + 1, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(reversed_attention_levels[i] and not with_conditioning), - with_cross_attn=(reversed_attention_levels[i] and with_conditioning), - num_head_channels=reversed_num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.up_blocks.append(up_block) - - # out - self.out = nn.Sequential( - nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=num_channels[0], - eps=norm_eps, - affine=True, - ), + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), - zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=num_channels[0], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ), + nn.Linear(time_embed_dim, new_time_embed_dim), ) def forward( @@ -2047,108 +133,37 @@ def forward( bottom_region_index_tensor: torch.Tensor | None = None, spacing_tensor: torch.Tensor | None = None, ) -> 
torch.Tensor: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). - mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]).to(dtype=x.dtype) emb = self.time_embed(t_emb) - # print(f't_emb: {t_emb}; timesteps {timesteps}.') - # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') - # print(f"emb: {torch.norm(emb)}") - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) + if self.num_class_embeds is not None and class_labels is not None: + class_emb = self.class_embedding(class_labels).to(dtype=x.dtype) emb = emb + class_emb - # 3. input - if self.input_top_region_index: - _emb = self.top_region_index_layer(top_region_index_tensor) - # print(f"top_region_index_layer: {torch.norm(_emb)} {_emb.size()}") - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') - if self.input_bottom_region_index: - _emb = self.bottom_region_index_layer(bottom_region_index_tensor) - # print(f"bottom_region_index_layer: {torch.norm(_emb)}") - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') - if self.input_spacing: - _emb = self.spacing_layer(spacing_tensor) - # print(f"spacing_layer: {torch.norm(_emb)}") - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') + if self.input_top_region_index and top_region_index_tensor is not None: + emb = torch.cat((emb, self.top_region_index_layer(top_region_index_tensor)), dim=1) + if self.input_bottom_region_index and bottom_region_index_tensor is not None: + emb = torch.cat((emb, self.bottom_region_index_layer(bottom_region_index_tensor)), dim=1) + if self.input_spacing and spacing_tensor is not None: + emb = torch.cat((emb, self.spacing_layer(spacing_tensor)), dim=1) - # 3. initial convolution h = self.conv_in(x) - # print(f"x: {torch.norm(x)}; h: {torch.norm(h)}") - - # 4. 
down - if context is not None and self.with_conditioning is False: - raise ValueError( - "model should have with_conditioning = True if context is provided" - ) - down_block_res_samples: list[torch.Tensor] = [h] + down_block_res_samples = [h] for downsample_block in self.down_blocks: - h, res_samples = downsample_block( - hidden_states=h, temb=emb, context=context - ) - for residual in res_samples: - down_block_res_samples.append(residual) + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + down_block_res_samples.extend(res_samples) - # Additional residual conections for Controlnets if down_block_additional_residuals is not None: - new_down_block_res_samples = () - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = ( - down_block_res_sample + down_block_additional_residual - ) - new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = [res + add_res for res, add_res in zip(down_block_res_samples, down_block_additional_residuals)] - down_block_res_samples = new_down_block_res_samples - - # 5. mid h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # Additional residual conections for Controlnets if mid_block_additional_residual is not None: h = h + mid_block_additional_residual - # 6. up for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - h = upsample_block( - hidden_states=h, - res_hidden_states_list=res_samples, - temb=emb, - context=context, - ) + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - # 7. 
output block h = self.out(h) - return h From f1eeeb3b6cd88ccf308ae7218aab49a340a3d879 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 21:29:59 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index b4b4bc878c..1ddff9648a 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -86,11 +86,11 @@ def __init__( use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) - + self.input_top_region_index = input_top_region_index self.input_bottom_region_index = input_bottom_region_index self.input_spacing = input_spacing - + time_embed_dim = num_channels[0] * 4 new_time_embed_dim = time_embed_dim if self.input_top_region_index: From 76e3f4aaa58df94d52ff7894eebce90c159b26cb Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 15:38:02 -0600 Subject: [PATCH 05/36] update Signed-off-by: Dong Yang --- .../networks/diffusion_model_unet_maisi.py | 56 ++++++++++++++++--- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index b4b4bc878c..1909118ba5 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -38,12 +38,39 @@ __all__ = ["DiffusionModelUNetMaisi"] - from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet from generative.networks.nets.diffusion_model_unet import get_timestep_embedding class DiffusionModelUNetMaisi(DiffusionModelUNet): + """ + DiffusionModelUNetMaisi extends the DiffusionModelUNet class to support additional + input features like region indices and spacing. This class is specifically designed + for enhanced image synthesis in medical applications. + + Args: + spatial_dims: Number of spatial dimensions (2 or 3). + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks per level. + num_channels: Tuple of block output channels. + attention_levels: List indicating which levels have attention. + norm_num_groups: Number of groups for group normalization. + norm_eps: Epsilon for group normalization. + resblock_updown: If True, use residual blocks for up/downsampling. + num_head_channels: Number of channels in each attention head. + with_conditioning: If True, add spatial transformers for conditioning. + transformer_num_layers: Number of layers of Transformer blocks to use. + cross_attention_dim: Number of context dimensions for cross-attention. + num_class_embeds: Number of class embeddings for class-conditional generation. + upcast_attention: If True, upcast attention operations to full precision. + use_flash_attention: If True, use flash attention for memory efficient attention. + dropout_cattn: Dropout value for cross-attention layers. + input_top_region_index: If True, include top region index in the input. + input_bottom_region_index: If True, include bottom region index in the input. 
+ input_spacing: If True, include spacing information in the input. + """ + def __init__( self, spatial_dims: int, @@ -86,11 +113,11 @@ def __init__( use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) - + self.input_top_region_index = input_top_region_index self.input_bottom_region_index = input_bottom_region_index self.input_spacing = input_spacing - + time_embed_dim = num_channels[0] * 4 new_time_embed_dim = time_embed_dim if self.input_top_region_index: @@ -133,6 +160,23 @@ def forward( bottom_region_index_tensor: torch.Tensor | None = None, spacing_tensor: torch.Tensor | None = None, ) -> torch.Tensor: + """ + Forward pass through the DiffusionModelUNetMaisi. + + Args: + x: Input tensor of shape (N, C, SpatialDims). + timesteps: Timestep tensor of shape (N,). + context: Optional context tensor of shape (N, 1, ContextDim). + class_labels: Optional class label tensor of shape (N,). + down_block_additional_residuals: Optional additional residual tensors for down blocks. + mid_block_additional_residual: Optional additional residual tensor for mid block. + top_region_index_tensor: Optional tensor for top region index of shape (N, 4). + bottom_region_index_tensor: Optional tensor for bottom region index of shape (N, 4). + spacing_tensor: Optional tensor for spacing information of shape (N, 3). + + Returns: + Output tensor of shape (N, C, SpatialDims). + """ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]).to(dtype=x.dtype) emb = self.time_embed(t_emb) @@ -162,8 +206,4 @@ def forward( for upsample_block in self.up_blocks: res_samples = down_block_res_samples[-len(upsample_block.resnets):] - down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)] - h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - - h = self.out(h) - return h + down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets From c1749e0ef8064e1e4be08676c221f6b88b6e14c2 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 15:43:14 -0600 Subject: [PATCH 06/36] update inferers Signed-off-by: Dong Yang --- .../maisi/inferers/inferer_maisi.py | 229 +----------------- 1 file changed, 2 insertions(+), 227 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 5b4aec5950..1c8ee6e942 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -26,7 +26,7 @@ IF_PROFILE = False -class DiffusionInferer(Inferer): +class DiffusionInfererMaisi(Inferer): """ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. @@ -381,7 +381,7 @@ def _get_decoder_log_likelihood( return log_probs -class LatentDiffusionInferer(DiffusionInferer): +class LatentDiffusionInfererMaisi(DiffusionInfererMaisi): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from the model. @@ -561,228 +561,3 @@ def get_likelihood( intermediates = [resizer(x) for x in intermediates] outputs = (outputs[0], intermediates) return outputs - - -class VQVAETransformerInferer(Inferer): - """ - Class to perform inference with a VQVAE + Transformer model. 
- """ - - def __init__(self) -> None: - Inferer.__init__(self) - - def __call__( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - return_latent: bool = False, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: - """ - Implements the forward pass for a supervised training iteration. - - Args: - inputs: input image to which the latent representation will be extracted. - vqvae_model: first stage model. - transformer_model: autoregressive transformer model. - ordering: ordering of the quantised latent representation. - return_latent: also return latent sequence and spatial dim of the latent. - condition: conditioning for network input. - """ - with torch.no_grad(): - latent = vqvae_model.index_quantize(inputs) - - latent_spatial_dim = tuple(latent.shape[1:]) - latent = latent.reshape(latent.shape[0], -1) - latent = latent[:, ordering.get_sequence_ordering()] - - # get the targets for the loss - target = latent.clone() - # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. - # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. - latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) - # crop the last token as we do not need the probability of the token that follows it - latent = latent[:, :-1] - latent = latent.long() - - # train on a part of the sequence if it is longer than max_seq_length - seq_len = latent.shape[1] - max_seq_len = transformer_model.max_seq_len - if max_seq_len < seq_len: - start = torch.randint( - low=0, high=seq_len + 1 - max_seq_len, size=(1,) - ).item() - else: - start = 0 - prediction = transformer_model( - x=latent[:, start : start + max_seq_len], context=condition - ) - if return_latent: - return ( - prediction, - target[:, start : start + max_seq_len], - latent_spatial_dim, - ) - else: - return prediction - - @torch.no_grad() - def sample( - self, - latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], - starting_tokens: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - conditioning: torch.Tensor | None = None, - temperature: float = 1.0, - top_k: int | None = None, - verbose: bool = True, - ) -> torch.Tensor: - """ - Sampling function for the VQVAE + Transformer model. - - Args: - latent_spatial_dim: shape of the sampled image. - starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. - vqvae_model: first stage model. - transformer_model: model to sample from. - conditioning: Conditioning for network input. - temperature: temperature for sampling. - top_k: top k sampling. - verbose: if true, prints the progression bar of the sampling process. 
- """ - seq_len = math.prod(latent_spatial_dim) - - if verbose and has_tqdm: - progress_bar = tqdm(range(seq_len)) - else: - progress_bar = iter(range(seq_len)) - - latent_seq = starting_tokens.long() - for _ in progress_bar: - # if the sequence context is growing too long we must crop it at block_size - if latent_seq.size(1) <= transformer_model.max_seq_len: - idx_cond = latent_seq - else: - idx_cond = latent_seq[:, -transformer_model.max_seq_len :] - - # forward the model to get the logits for the index in the sequence - logits = transformer_model(x=idx_cond, context=conditioning) - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("Inf") - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # remove the chance to be sampled the BOS token - probs[:, vqvae_model.num_embeddings] = 0 - # sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) - # append sampled index to the running sequence and continue - latent_seq = torch.cat((latent_seq, idx_next), dim=1) - - latent_seq = latent_seq[:, 1:] - latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] - latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) - - return vqvae_model.decode_samples(latent) - - @torch.no_grad() - def get_likelihood( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", - verbose: bool = False, - ) -> torch.Tensor: - """ - Computes the log-likelihoods of the latent representations of the input. - - Args: - inputs: input images, NxCxHxW[xD] - vqvae_model: first stage model. - transformer_model: autoregressive transformer model. - ordering: ordering of the quantised latent representation. - condition: conditioning for network input. - resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial - dimension as the input images. - resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', - or 'trilinear; - verbose: if true, prints the progression bar of the sampling process. - - """ - if resample_latent_likelihoods and resample_interpolation_mode not in ( - "nearest", - "bilinear", - "trilinear", - ): - raise ValueError( - f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" - ) - - with torch.no_grad(): - latent = vqvae_model.index_quantize(inputs) - - latent_spatial_dim = tuple(latent.shape[1:]) - latent = latent.reshape(latent.shape[0], -1) - latent = latent[:, ordering.get_sequence_ordering()] - seq_len = math.prod(latent_spatial_dim) - - # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. - # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. 
- latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) - latent = latent.long() - - # get the first batch, up to max_seq_length, efficiently - logits = transformer_model( - x=latent[:, : transformer_model.max_seq_len], context=condition - ) - probs = F.softmax(logits, dim=-1) - # target token for each set of logits is the next token along - target = latent[:, 1:] - probs = torch.gather( - probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2) - ).squeeze(2) - - # if we have not covered the full sequence we continue with inefficient looping - if probs.shape[1] < target.shape[1]: - if verbose and has_tqdm: - progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) - else: - progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) - - for i in progress_bar: - idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] - # forward the model to get the logits for the index in the sequence - logits = transformer_model(x=idx_cond, context=condition) - # pluck the logits at the final step - logits = logits[:, -1, :] - # apply softmax to convert logits to (normalized) probabilities - p = F.softmax(logits, dim=-1) - # select correct values and append - p = torch.gather(p, 1, target[:, i].unsqueeze(1)) - - probs = torch.cat((probs, p), dim=1) - - # convert to log-likelihood - probs = torch.log(probs) - - # reshape - probs = probs[:, ordering.get_revert_sequence_ordering()] - probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) - if resample_latent_likelihoods: - resizer = nn.Upsample( - size=inputs.shape[2:], mode=resample_interpolation_mode - ) - probs_reshaped = resizer(probs_reshaped[:, None, ...]) - - return probs_reshaped From 042290823e251da704017a539430f2a9985323de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 21:43:42 +0000 Subject: [PATCH 07/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/inferers/inferer_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 1c8ee6e942..e9fea3285f 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -12,11 +12,10 @@ from __future__ import annotations import math -from collections.abc import Callable, Sequence +from collections.abc import Callable import torch import torch.nn as nn -import torch.nn.functional as F from monai.inferers import Inferer from monai.utils import optional_import From 926eb7f4de98585ae1b4ec2b608cebc1b63fe748 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 16:04:20 -0600 Subject: [PATCH 08/36] update inferers Signed-off-by: Dong Yang --- .../maisi/inferers/inferer_maisi.py | 240 +----------------- 1 file changed, 12 insertions(+), 228 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 1c8ee6e942..dae3fa8c08 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -26,19 +26,17 @@ IF_PROFILE = False -class DiffusionInfererMaisi(Inferer): +class DiffusionInfererMaisi(DiffusionInferer): """ - DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + 
DiffusionInfererMaisi takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass for a training iteration, and sample from the model. - Args: scheduler: diffusion scheduler. """ def __init__(self, scheduler: nn.Module) -> None: - Inferer.__init__(self) - self.scheduler = scheduler + super().__init__(scheduler=scheduler) def __call__( self, @@ -62,6 +60,9 @@ def __call__( timesteps: random timesteps. condition: Conditioning for network input. mode: Conditioning mode for the network. + top_region_index_tensor: tensor for top region index. + bottom_region_index_tensor: tensor for bottom region index. + spacing_tensor: tensor for spacing. """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -102,12 +103,15 @@ def sample( Args: input_noise: random noise, of the same shape as the desired sample. diffusion_model: model to sample from. - scheduler: diffusion scheduler. If none provided will use the class attribute scheduler - save_intermediates: whether to return intermediates along the sampling change - intermediate_steps: if save_intermediates is True, saves every n steps + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change. + intermediate_steps: if save_intermediates is True, saves every n steps. conditioning: Conditioning for network input. mode: Conditioning mode for the network. verbose: if true, prints the progression bar of the sampling process. + top_region_index_tensor: tensor for top region index. + bottom_region_index_tensor: tensor for bottom region index. + spacing_tensor: tensor for spacing. """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -149,236 +153,16 @@ def sample( spacing_tensor=spacing_tensor, ) - if IF_PROFILE: - torch.cuda.nvtx.range_pop() - - # diff = torch.norm(model_output).cpu().item() - # print(diff) - # with open("diff.txt", "a") as file: - # file.write(f"{diff}\n") - # 2. compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) - if IF_PROFILE: - torch.cuda.cudart().cudaProfilerStop() - if save_intermediates: return image, intermediates else: return image - @torch.no_grad() - def get_likelihood( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, - ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - """ - Computes the log-likelihoods for an input. - - Args: - inputs: input images, NxCxHxW[xD] - diffusion_model: model to compute likelihood from - scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. - save_intermediates: save the intermediate spatial KL maps - conditioning: Conditioning for network input. - mode: Conditioning mode for the network. - original_input_range: the [min,max] intensity range of the input data before any scaling was applied. - scaled_input_range: the [min,max] intensity range of the input data after scaling. - verbose: if true, prints the progression bar of the sampling process. 
- """ - - if not scheduler: - scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": - raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" - ) - if mode not in ["crossattn", "concat"]: - raise NotImplementedError(f"{mode} condition is not supported") - if verbose and has_tqdm: - progress_bar = tqdm(scheduler.timesteps) - else: - progress_bar = iter(scheduler.timesteps) - intermediates = [] - noise = torch.randn_like(inputs).to(inputs.device) - total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) - for t in progress_bar: - timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise( - original_samples=inputs, noise=noise, timesteps=timesteps - ) - if mode == "concat": - noisy_image = torch.cat([noisy_image, conditioning], dim=1) - model_output = diffusion_model( - noisy_image, timesteps=timesteps, context=None - ) - else: - model_output = diffusion_model( - x=noisy_image, timesteps=timesteps, context=conditioning - ) - # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[ - 1 - ] * 2 and scheduler.variance_type in ["learned", "learned_range"]: - model_output, predicted_variance = torch.split( - model_output, inputs.shape[1], dim=1 - ) - else: - predicted_variance = None - - # 1. compute alphas, betas - alpha_prod_t = scheduler.alphas_cumprod[t] - alpha_prod_t_prev = ( - scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one - ) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": - pred_original_sample = ( - noisy_image - beta_prod_t ** (0.5) * model_output - ) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": - pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * noisy_image - ( - beta_prod_t**0.5 - ) * model_output - # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = ( - alpha_prod_t_prev ** (0.5) * scheduler.betas[t] - ) / beta_prod_t - current_sample_coeff = ( - scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t - ) - - # 5. 
Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - predicted_mean = ( - pred_original_sample_coeff * pred_original_sample - + current_sample_coeff * noisy_image - ) - - # get the posterior mean and variance - posterior_mean = scheduler._get_mean( - timestep=t, x_0=inputs, x_t=noisy_image - ) - posterior_variance = scheduler._get_variance( - timestep=t, predicted_variance=predicted_variance - ) - - log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = ( - torch.log(predicted_variance) - if predicted_variance - else log_posterior_variance - ) - - if t == 0: - # compute -log p(x_0|x_1) - kl = -self._get_decoder_log_likelihood( - inputs=inputs, - means=predicted_mean, - log_scales=0.5 * log_predicted_variance, - original_input_range=original_input_range, - scaled_input_range=scaled_input_range, - ) - else: - # compute kl between two normals - kl = 0.5 * ( - -1.0 - + log_predicted_variance - - log_posterior_variance - + torch.exp(log_posterior_variance - log_predicted_variance) - + ((posterior_mean - predicted_mean) ** 2) - * torch.exp(-log_predicted_variance) - ) - total_kl += kl.view(kl.shape[0], -1).mean(axis=1) - if save_intermediates: - intermediates.append(kl.cpu()) - - if save_intermediates: - return total_kl, intermediates - else: - return total_kl - - def _approx_standard_normal_cdf(self, x): - """ - A fast approximation of the cumulative distribution function of the - standard normal. Code adapted from https://github.com/openai/improved-diffusion. - """ - - return 0.5 * ( - 1.0 - + torch.tanh( - torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) - * (x + 0.044715 * torch.pow(x, 3)) - ) - ) - - def _get_decoder_log_likelihood( - self, - inputs: torch.Tensor, - means: torch.Tensor, - log_scales: torch.Tensor, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - ) -> torch.Tensor: - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. Code adapted from https://github.com/openai/improved-diffusion. - - Args: - input: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - means: the Gaussian mean Tensor. - log_scales: the Gaussian log stddev Tensor. - original_input_range: the [min,max] intensity range of the input data before any scaling was applied. - scaled_input_range: the [min,max] intensity range of the input data after scaling. 
- """ - assert inputs.shape == means.shape - bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( - original_input_range[1] - original_input_range[0] - ) - centered_x = inputs - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + bin_width / 2) - cdf_plus = self._approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - bin_width / 2) - cdf_min = self._approx_standard_normal_cdf(min_in) - log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = torch.where( - inputs < -0.999, - log_cdf_plus, - torch.where( - inputs > 0.999, - log_one_minus_cdf_min, - torch.log(cdf_delta.clamp(min=1e-12)), - ), - ) - assert log_probs.shape == inputs.shape - return log_probs class LatentDiffusionInfererMaisi(DiffusionInfererMaisi): From f99f2f33503a3614c31ba7b600778b1f6cc6b5cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 22:04:50 +0000 Subject: [PATCH 09/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/inferers/inferer_maisi.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 7653699b7c..9a3aab0197 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -11,12 +11,10 @@ from __future__ import annotations -import math from collections.abc import Callable import torch import torch.nn as nn -from monai.inferers import Inferer from monai.utils import optional_import tqdm, has_tqdm = optional_import("tqdm", name="tqdm") From 8ec8848eea07f97c339bb2b3a46eab7cd0238b37 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 16:09:56 -0600 Subject: [PATCH 10/36] update inferers Signed-off-by: Dong Yang --- .../maisi/inferers/inferer_maisi.py | 45 +++++++++++-------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 7653699b7c..7bc06dea71 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -11,18 +11,19 @@ from __future__ import annotations -import math from collections.abc import Callable +from functools import partial import torch import torch.nn as nn -from monai.inferers import Inferer +from monai.data import decollate_batch from monai.utils import optional_import tqdm, has_tqdm = optional_import("tqdm", name="tqdm") +from generative.inferers.inferer import DiffusionInferer +from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet -IF_PROFILE = False class DiffusionInfererMaisi(DiffusionInferer): @@ -124,12 +125,7 @@ def sample( progress_bar = iter(scheduler.timesteps) intermediates = [] - if IF_PROFILE: - torch.cuda.cudart().cudaProfilerStart() - for t in progress_bar: - if IF_PROFILE: - torch.cuda.nvtx.range_push("forward") # 1. 
predict noise model_output if mode == "concat": @@ -298,6 +294,8 @@ def get_likelihood( verbose: bool = True, resample_latent_likelihoods: bool = False, resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + quantized: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -317,17 +315,29 @@ def get_likelihood( dimension as the input images. resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM + are quantized or not. """ - if resample_latent_likelihoods and resample_interpolation_mode not in ( - "nearest", - "bilinear", - "trilinear", - ): + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) - latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor - outputs = super().get_likelihood( + + autoencode = autoencoder_model.encode_stage_2_inputs + if isinstance(autoencoder_model, VQVAE): + autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) + latents = autoencode(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + get_likelihood = super().get_likelihood + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + get_likelihood = partial(super().get_likelihood, seg=seg) + + outputs = get_likelihood( inputs=latents, diffusion_model=diffusion_model, scheduler=scheduler, @@ -336,11 +346,10 @@ def get_likelihood( mode=mode, verbose=verbose, ) + if save_intermediates and resample_latent_likelihoods: intermediates = outputs[1] - resizer = nn.Upsample( - size=inputs.shape[2:], mode=resample_interpolation_mode - ) + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) intermediates = [resizer(x) for x in intermediates] outputs = (outputs[0], intermediates) return outputs From 744d8293266077e52a2432e90d3325363f4191dd Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 21:06:07 -0600 Subject: [PATCH 11/36] update inferers Signed-off-by: Dong Yang --- monai/apps/generation/maisi/inferers/inferer_maisi.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index e54b0ae3d1..7bc06dea71 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -16,10 +16,7 @@ import torch import torch.nn as nn -<<<<<<< HEAD from monai.data import decollate_batch -======= ->>>>>>> f99f2f33503a3614c31ba7b600778b1f6cc6b5cf from monai.utils import optional_import tqdm, has_tqdm = optional_import("tqdm", name="tqdm") From de414bb7fe70dc538d251ba5f52c4b2273964c76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jun 2024 03:06:33 +0000 Subject: [PATCH 12/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
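Note for readers following the API changes in the preceding patches: below is a minimal usage sketch of the sampling entry point added by DiffusionInfererMaisi, with the extra top/bottom body-region index and voxel-spacing inputs. The class and argument names are taken from the hunks and unit tests in this series; the import paths, DiffusionModelUNetMaisi flags, and tensor shapes are assumptions based on the new file layout and the draft tests, not part of the patch itself.

import torch

from monai.apps.generation.maisi.inferers import DiffusionInfererMaisi
from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
from monai.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# 3D UNet that accepts the extra MAISI conditioning inputs (flags mirror the unit tests in this series).
model = DiffusionModelUNetMaisi(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, False),
    norm_num_groups=8,
    input_top_region_index=True,
    input_bottom_region_index=True,
    input_spacing=True,
).to(device)
model.eval()

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=10)
inferer = DiffusionInfererMaisi(scheduler=scheduler)

noise = torch.randn(1, 1, 16, 16, 16, device=device)
# Body-region and spacing conditioning tensors; shapes follow the draft tests (assumed, not prescribed).
top_region_index_tensor = torch.rand(1, 4, device=device)
bottom_region_index_tensor = torch.rand(1, 4, device=device)
spacing_tensor = torch.rand(1, 3, device=device)

image = inferer.sample(
    input_noise=noise,
    diffusion_model=model,
    scheduler=scheduler,
    top_region_index_tensor=top_region_index_tensor,
    bottom_region_index_tensor=bottom_region_index_tensor,
    spacing_tensor=spacing_tensor,
)

LatentDiffusionInfererMaisi follows the same pattern but additionally takes an autoencoder_model and scale_factor, decoding the sampled latents back to image space, as exercised by the latent inferer tests later in this series.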
monai/apps/generation/maisi/inferers/inferer_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 7bc06dea71..09d500efe1 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -22,7 +22,7 @@ tqdm, has_tqdm = optional_import("tqdm", name="tqdm") from generative.inferers.inferer import DiffusionInferer -from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from generative.networks.nets import VQVAE, SPADEDiffusionModelUNet From 8fc994cd56998a74d9e22addc69a57482b4da483 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 21:06:45 -0600 Subject: [PATCH 13/36] update inferers Signed-off-by: Dong Yang --- monai/apps/generation/maisi/inferers/inferer_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 7bc06dea71..a7d0915510 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -22,8 +22,7 @@ tqdm, has_tqdm = optional_import("tqdm", name="tqdm") from generative.inferers.inferer import DiffusionInferer -from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet - +from generative.networks.nets import VQVAE, SPADEDiffusionModelUNet class DiffusionInfererMaisi(DiffusionInferer): From 6638d17b00f5494195270ff0b59a432dc04bfd61 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 21:10:45 -0600 Subject: [PATCH 14/36] update inferers Signed-off-by: Dong Yang --- monai/apps/generation/maisi/inferers/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/monai/apps/generation/maisi/inferers/__init__.py b/monai/apps/generation/maisi/inferers/__init__.py index 1e97f89407..7b567b0311 100644 --- a/monai/apps/generation/maisi/inferers/__init__.py +++ b/monai/apps/generation/maisi/inferers/__init__.py @@ -8,3 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +from .inferer import ( + DiffusionInfererMaisi, + LatentDiffusionInfererMaisi +) From f6e5d2bf600feba1cc5f0a63fa907fdae98f24a8 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 21:43:59 -0600 Subject: [PATCH 15/36] add unit test cases Signed-off-by: Dong Yang --- tests/test_diffusion_inferer_maisi.py | 265 ++++++++ tests/test_diffusion_model_unet_maisi.py | 619 +++++++++++++++++++ tests/test_latent_diffusion_inferer_maisi.py | 366 +++++++++++ 3 files changed, 1250 insertions(+) create mode 100644 tests/test_diffusion_inferer_maisi.py create mode 100644 tests/test_diffusion_model_unet_maisi.py create mode 100644 tests/test_latent_diffusion_inferer_maisi.py diff --git a/tests/test_diffusion_inferer_maisi.py b/tests/test_diffusion_inferer_maisi.py new file mode 100644 index 0000000000..764ee015a4 --- /dev/null +++ b/tests/test_diffusion_inferer_maisi.py @@ -0,0 +1,265 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInfererMaisi +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + +class TestDiffusionInfererMaisi(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample = inferer( + inputs=input, + noise=noise, + diffusion_model=model, + timesteps=timesteps, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def 
test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat(self, model_params, input_shape): + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = 
DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call_conditioned_concat(self, model_params, input_shape): + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInfererMaisi(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + top_region_index_tensor = torch.randn(input_shape).to(device) + bottom_region_index_tensor = torch.randn(input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + sample = inferer( + inputs=input, + noise=noise, + diffusion_model=model, + timesteps=timesteps, + condition=conditioning, + mode="concat", + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor + ) + self.assertEqual(sample.shape, input_shape) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py new file mode 100644 index 0000000000..d9bcb0f938 --- /dev/null +++ b/tests/test_diffusion_model_unet_maisi.py @@ -0,0 +1,619 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + 
"num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNetMaisi2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNetMaisi(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + 
num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_context_with_conditioning_none(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + @skipUnless(has_einops, "Requires einops") + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + 
num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNetMaisi(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_shape_with_additional_inputs(self, input_param): + input_param["input_top_region_index"] = True + input_param["input_bottom_region_index"] = True + input_param["input_spacing"] = True + net = DiffusionModelUNetMaisi(**input_param) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + top_region_index_tensor=torch.rand((1, 4)), + bottom_region_index_tensor=torch.rand((1, 4)), + spacing_tensor=torch.rand((1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNetMaisi3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNetMaisi(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNetMaisi( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = DiffusionModelUNetMaisi( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNetMaisi(**input_param) + + @parameterized.expand(DROPOUT_OK) + @skipUnless(has_einops, "Requires einops") + def test_right_dropout(self, input_param): + _ = DiffusionModelUNetMaisi(**input_param) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + 
cross_attention_dim=3, + transformer_num_layers=1, + norm_num_groups=8, + ) + + tmpdir = tempfile.mkdtemp() + key = "diffusion_model_unet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "diffusion_model_unet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") + def test_shape_with_additional_inputs(self, input_param): + input_param["input_top_region_index"] = True + input_param["input_bottom_region_index"] = True + input_param["input_spacing"] = True + net = DiffusionModelUNetMaisi(**input_param) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + top_region_index_tensor=torch.rand((1, 4)), + bottom_region_index_tensor=torch.rand((1, 4)), + spacing_tensor=torch.rand((1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_latent_diffusion_inferer_maisi.py b/tests/test_latent_diffusion_inferer_maisi.py new file mode 100644 index 0000000000..1a1c0cff3d --- /dev/null +++ b/tests/test_latent_diffusion_inferer_maisi.py @@ -0,0 +1,366 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInfererMaisi +from monai.networks.nets import AutoencoderKL, DiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler +from monai.utils import optional_import + +_, has_einops = optional_import("einops") +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestLatentDiffusionInfererMaisi(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + 
bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + 
top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, 
+ diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") + + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + if stage_2 is None: + raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) + spacing_tensor = torch.randn(input_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() From 346c645990e2e413cc455060e2a275c0a65fc0e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jun 2024 03:44:27 +0000 Subject: [PATCH 16/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_diffusion_model_unet_maisi.py | 1 - tests/test_latent_diffusion_inferer_maisi.py | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index d9bcb0f938..78ec7094c4 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -21,7 +21,6 @@ from monai.apps import download_url from monai.networks import eval_mode -from monai.networks.nets import DiffusionModelUNet from monai.utils import optional_import from tests.utils import skip_if_downloading_fails, testing_data_config diff --git a/tests/test_latent_diffusion_inferer_maisi.py b/tests/test_latent_diffusion_inferer_maisi.py index 1a1c0cff3d..9a84cf736d 100644 --- a/tests/test_latent_diffusion_inferer_maisi.py +++ b/tests/test_latent_diffusion_inferer_maisi.py @@ -65,7 +65,7 @@ def test_prediction_shape( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: 
{ae_model_type}") - + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -107,7 +107,7 @@ def test_sample_shape( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -146,7 +146,7 @@ def test_sample_intermediates( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -188,7 +188,7 @@ def test_get_likelihoods( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -229,7 +229,7 @@ def test_resample_likelihoods( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -271,7 +271,7 @@ def test_prediction_shape_conditioned_concat( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2_params = stage_2_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel @@ -322,7 +322,7 @@ def test_sample_shape_conditioned_concat( stage_1 = AutoencoderKL(**autoencoder_params) else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - + stage_2_params = stage_2_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel From d035b73e058640c33461e269f4761b040a9933b0 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 04:43:25 +0000 Subject: [PATCH 17/36] fix format error Signed-off-by: dongyang0122 --- .../generation/maisi/inferers/__init__.py | 5 +-- .../maisi/inferers/inferer_maisi.py | 17 +++------ .../networks/diffusion_model_unet_maisi.py | 33 +++++++++-------- tests/test_diffusion_inferer_maisi.py | 36 ++++++++++--------- tests/test_diffusion_model_unet_maisi.py | 34 +----------------- tests/test_latent_diffusion_inferer_maisi.py | 23 ++++++------ 6 files changed, 55 insertions(+), 93 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/__init__.py b/monai/apps/generation/maisi/inferers/__init__.py index 7b567b0311..5bdca05bd0 100644 --- a/monai/apps/generation/maisi/inferers/__init__.py +++ b/monai/apps/generation/maisi/inferers/__init__.py @@ -11,7 +11,4 @@ from __future__ import annotations -from .inferer import ( - DiffusionInfererMaisi, - LatentDiffusionInfererMaisi -) +from .inferer_maisi import DiffusionInfererMaisi, LatentDiffusionInfererMaisi diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index 
a7d0915510..dce5491319 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -16,14 +16,14 @@ import torch import torch.nn as nn +from generative.inferers.inferer import DiffusionInferer +from generative.networks.nets import VQVAE, SPADEDiffusionModelUNet + from monai.data import decollate_batch from monai.utils import optional_import tqdm, has_tqdm = optional_import("tqdm", name="tqdm") -from generative.inferers.inferer import DiffusionInferer -from generative.networks.nets import VQVAE, SPADEDiffusionModelUNet - class DiffusionInfererMaisi(DiffusionInferer): """ @@ -66,9 +66,7 @@ def __call__( if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") - noisy_image = self.scheduler.add_noise( - original_samples=inputs, noise=noise, timesteps=timesteps - ) + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) if mode == "concat": noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None @@ -158,7 +156,6 @@ def sample( return image - class LatentDiffusionInfererMaisi(DiffusionInfererMaisi): """ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can @@ -268,11 +265,7 @@ def sample( if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - intermediates.append( - autoencoder_model.decode_stage_2_outputs( - latent_intermediate / self.scale_factor - ) - ) + intermediates.append(autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)) return image, intermediates else: diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 1909118ba5..c2ab7c7e4e 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -38,8 +38,7 @@ __all__ = ["DiffusionModelUNetMaisi"] -from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet -from generative.networks.nets.diffusion_model_unet import get_timestep_embedding +from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet, get_timestep_embedding class DiffusionModelUNetMaisi(DiffusionModelUNet): @@ -122,30 +121,22 @@ def __init__( new_time_embed_dim = time_embed_dim if self.input_top_region_index: self.top_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_bottom_region_index: self.bottom_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_spacing: self.spacing_layer = nn.Sequential( - nn.Linear(3, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, new_time_embed_dim), + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, 
new_time_embed_dim) ) def forward( @@ -198,12 +189,20 @@ def forward( down_block_res_samples.extend(res_samples) if down_block_additional_residuals is not None: - down_block_res_samples = [res + add_res for res, add_res in zip(down_block_res_samples, down_block_additional_residuals)] + down_block_res_samples = [ + res + add_res for res, add_res in zip(down_block_res_samples, down_block_additional_residuals) + ] h = self.middle_block(hidden_states=h, temb=emb, context=context) if mid_block_additional_residual is not None: h = h + mid_block_additional_residual for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets):] - down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. output block + h = self.out(h) + + return h diff --git a/tests/test_diffusion_inferer_maisi.py b/tests/test_diffusion_inferer_maisi.py index 764ee015a4..ee5d1ec728 100644 --- a/tests/test_diffusion_inferer_maisi.py +++ b/tests/test_diffusion_inferer_maisi.py @@ -9,14 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless import torch from parameterized import parameterized -from monai.inferers import DiffusionInfererMaisi -from monai.networks.nets import DiffusionModelUNet +from monai.apps.generation.maisi.inferers.inferer_maisi import DiffusionInfererMaisi +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks.schedulers import DDIMScheduler, DDPMScheduler from monai.utils import optional_import @@ -52,11 +54,12 @@ ], ] + class TestDiffusionInfererMaisi(unittest.TestCase): @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_call(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -76,14 +79,14 @@ def test_call(self, model_params, input_shape): timesteps=timesteps, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sample_intermediates(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -102,14 +105,14 @@ def test_sample_intermediates(self, model_params, input_shape): intermediate_steps=1, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" 
model.to(device) model.eval() @@ -128,14 +131,14 @@ def test_ddpm_sampler(self, model_params, input_shape): intermediate_steps=1, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -154,7 +157,7 @@ def test_ddim_sampler(self, model_params, input_shape): intermediate_steps=1, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(len(intermediates), 10) @@ -163,7 +166,7 @@ def test_ddim_sampler(self, model_params, input_shape): def test_sampler_conditioned(self, model_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -184,7 +187,7 @@ def test_sampler_conditioned(self, model_params, input_shape): conditioning=conditioning, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(len(intermediates), 10) @@ -196,7 +199,7 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): model_params["in_channels"] = model_params["in_channels"] + n_concat_channel model_params["cross_attention_dim"] = None model_params["with_conditioning"] = False - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -220,7 +223,7 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): mode="concat", top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(len(intermediates), 10) @@ -232,7 +235,7 @@ def test_call_conditioned_concat(self, model_params, input_shape): model_params["in_channels"] = model_params["in_channels"] + n_concat_channel model_params["cross_attention_dim"] = None model_params["with_conditioning"] = False - model = DiffusionModelUNet(**model_params) + model = DiffusionModelUNetMaisi(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -257,9 +260,10 @@ def test_call_conditioned_concat(self, model_params, input_shape): mode="concat", top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor + spacing_tensor=spacing_tensor, ) self.assertEqual(sample.shape, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 78ec7094c4..069bb68879 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -11,18 +11,15 @@ from __future__ import annotations -import os -import tempfile 
import unittest from unittest import skipUnless import torch from parameterized import parameterized -from monai.apps import download_url +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks import eval_mode from monai.utils import optional_import -from tests.utils import skip_if_downloading_fails, testing_data_config _, has_einops = optional_import("einops") @@ -567,35 +564,6 @@ def test_wrong_dropout(self, input_param): def test_right_dropout(self, input_param): _ = DiffusionModelUNetMaisi(**input_param) - @skipUnless(has_einops, "Requires einops") - def test_compatibility_with_monai_generative(self): - # test loading weights from a model saved in MONAI Generative, version 0.2.3 - with skip_if_downloading_fails(): - net = DiffusionModelUNetMaisi( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=True, - cross_attention_dim=3, - transformer_num_layers=1, - norm_num_groups=8, - ) - - tmpdir = tempfile.mkdtemp() - key = "diffusion_model_unet_monai_generative_weights" - url = testing_data_config("models", key, "url") - hash_type = testing_data_config("models", key, "hash_type") - hash_val = testing_data_config("models", key, "hash_val") - filename = "diffusion_model_unet_monai_generative_weights.pt" - - weight_path = os.path.join(tmpdir, filename) - download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - - net.load_old_state_dict(torch.load(weight_path), verbose=False) - @parameterized.expand(UNCOND_CASES_3D) @skipUnless(has_einops, "Requires einops") def test_shape_with_additional_inputs(self, input_param): diff --git a/tests/test_latent_diffusion_inferer_maisi.py b/tests/test_latent_diffusion_inferer_maisi.py index 9a84cf736d..3495da3cda 100644 --- a/tests/test_latent_diffusion_inferer_maisi.py +++ b/tests/test_latent_diffusion_inferer_maisi.py @@ -15,10 +15,11 @@ from unittest import skipUnless import torch +from generative.networks.nets.autoencoderkl import AutoencoderKL from parameterized import parameterized -from monai.inferers import LatentDiffusionInfererMaisi -from monai.networks.nets import AutoencoderKL, DiffusionModelUNet +from monai.apps.generation.maisi.inferers.inferer_maisi import LatentDiffusionInfererMaisi +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks.schedulers import DDPMScheduler from monai.utils import optional_import @@ -38,7 +39,7 @@ "with_decoder_nonlocal_attn": False, "norm_num_groups": 4, }, - "DiffusionModelUNet", + "DiffusionModelUNetMaisi", { "spatial_dims": 2, "in_channels": 3, @@ -51,7 +52,7 @@ }, (1, 1, 8, 8), (1, 3, 4, 4), - ], + ] ] @@ -66,7 +67,7 @@ def test_prediction_shape( else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -108,7 +109,7 @@ def test_sample_shape( else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if 
stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -147,7 +148,7 @@ def test_sample_intermediates( else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -189,7 +190,7 @@ def test_get_likelihoods( else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -230,7 +231,7 @@ def test_resample_likelihoods( else: raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -275,7 +276,7 @@ def test_prediction_shape_conditioned_concat( stage_2_params = stage_2_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") @@ -326,7 +327,7 @@ def test_sample_shape_conditioned_concat( stage_2_params = stage_2_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - stage_2 = DiffusionModelUNet(**stage_2_params) if dm_model_type == "DiffusionModelUNet" else None + stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None if stage_2 is None: raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") From f21d77c5a00a8dc42a58029e933f4e89f0a27d71 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 05:05:11 +0000 Subject: [PATCH 18/36] update diffusion unet Signed-off-by: dongyang0122 --- .../networks/diffusion_model_unet_maisi.py | 1934 ++++++++++++++++- 1 file changed, 1863 insertions(+), 71 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index c2ab7c7e4e..e8aa1fe695 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -31,112 +31,1869 @@ from __future__ import annotations +import importlib.util +import math from collections.abc import Sequence import torch +import torch.nn.functional as F from torch import nn -__all__ = ["DiffusionModelUNetMaisi"] +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep -from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet, get_timestep_embedding +# To install 
xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + has_xformers = True +else: + xformers = None + has_xformers = False -class DiffusionModelUNetMaisi(DiffusionModelUNet): + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["CustomDiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class CrossAttention(nn.Module): + """ + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
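A minimal standalone sketch of this head-splitting reshape, with arbitrary illustrative sizes (batch 2, sequence length 16, 8 heads of 64 channels each); the variable names are hypothetical and only the shape bookkeeping matters:

import torch

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
x = torch.randn(batch, seq_len, num_heads * head_dim)

# split the channel dimension into heads, then fold the heads into the batch dimension
x_heads = x.reshape(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
x_heads = x_heads.reshape(batch * num_heads, seq_len, head_dim)
print(x_heads.shape)  # torch.Size([16, 16, 64])

# the inverse used after attention restores the original layout exactly
x_back = x_heads.reshape(batch, num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
x_back = x_back.reshape(batch, seq_len, num_heads * head_dim)
assert torch.equal(x_back, x)

Folding the heads into the batch dimension lets a single batched matrix multiply attend over all heads at once, which is why the attention implementations below operate on 3-D (batch*heads, sequence, head_dim) tensors.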
+ """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + return self.to_out(x) + + +class BasicTransformerBlock(nn.Module): + """ + A basic Transformer block. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
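The `_attention` helper above is ordinary scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, written with `baddbmm` so the 1/sqrt(d) scale is applied inside the batched matrix multiply. A small self-contained check of that equivalence, with arbitrary sizes:

import math
import torch

batch_heads, seq_len, head_dim = 4, 16, 64
query = torch.randn(batch_heads, seq_len, head_dim)
key = torch.randn(batch_heads, seq_len, head_dim)
value = torch.randn(batch_heads, seq_len, head_dim)
scale = 1 / math.sqrt(head_dim)

# textbook form: softmax(Q K^T / sqrt(d)) V
reference_scores = torch.bmm(query, key.transpose(-1, -2)) * scale
reference = torch.bmm(reference_scores.softmax(dim=-1), value)

# baddbmm form used in _attention: beta=0 discards the empty tensor, alpha applies the 1/sqrt(d) scale
scores = torch.baddbmm(
    torch.empty(batch_heads, seq_len, seq_len), query, key.transpose(-1, -2), beta=0, alpha=scale
)
out = torch.bmm(scores.softmax(dim=-1), value)
assert torch.allclose(out, reference, atol=1e-5)

On recent PyTorch releases, torch.nn.functional.scaled_dot_product_attention computes the same quantity with fused kernels, which is essentially what the optional xformers path provides in a memory-efficient form.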
+ """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
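Concretely, for 2-D inputs the SpatialTransformer described here flattens the feature map into a token sequence, runs the transformer blocks, and folds the result back into an image; a round-trip sketch with illustrative sizes:

import torch

batch, channels, height, width = 1, 32, 8, 8
x = torch.randn(batch, channels, height, width)

# image -> token sequence: (B, C, H, W) -> (B, H*W, C), one token per spatial location
tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

# ... the transformer blocks operate on `tokens` at this point ...

# token sequence -> image: (B, H*W, C) -> (B, C, H, W)
restored = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
assert torch.equal(restored, x)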
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
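AttentionBlock performs the same kind of flattening for plain self-attention; in the 3-D case relevant to volumetric data, every voxel becomes one attention token. A shape-only sketch with small illustrative sizes:

import torch

batch, channels, height, width, depth = 1, 8, 4, 4, 4
x = torch.randn(batch, channels, height, width, depth)

# (B, C, H, W, D) -> (B, H*W*D, C): every voxel becomes one attention token
tokens = x.view(batch, channels, height * width * depth).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 64, 8])

# the inverse restores the volumetric layout after attention
restored = tokens.transpose(-1, -2).reshape(batch, channels, height, width, depth)
assert torch.equal(restored, x)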
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') + + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class Downsample(nn.Module): + """ + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + return self.op(x) + + +class Upsample(nn.Module): + """ + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. 
+ """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + self.conv = None + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class ResnetBlock(nn.Module): + """ + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = 
self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + + return self.skip_connection(x) + h + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. 
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
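All of these down, mid, and up blocks pass the same `temb` vector into their ResnetBlocks, which inject it by projecting it to the block's channel count and broadcasting it over the spatial grid. A compact sketch of that injection for the 3-D case, with arbitrary sizes and hypothetical names:

import torch
import torch.nn as nn

batch, channels, temb_channels = 2, 8, 32
h = torch.randn(batch, channels, 16, 16, 16)   # a 3D feature map inside a ResnetBlock
emb = torch.randn(batch, temb_channels)        # the shared timestep/conditioning embedding

time_emb_proj = nn.Linear(temb_channels, channels)
temb = time_emb_proj(nn.SiLU()(emb))[:, :, None, None, None]  # (B, C) -> (B, C, 1, 1, 1)
h = h + temb                                   # broadcast-add over all spatial dimensions
print(h.shape)                                 # torch.Size([2, 8, 16, 16, 16])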
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
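The cross-attention used by CrossAttnMidBlock (and the corresponding down/up blocks) draws queries from the flattened image tokens and keys/values from the conditioning `context`. A single-head sketch with arbitrary sizes; the real CrossAttention additionally splits heads and applies an output projection:

import torch
import torch.nn as nn

batch, num_tokens, channels = 1, 64, 32    # flattened image tokens inside the SpatialTransformer
context_len, cross_attention_dim = 4, 16   # a short conditioning sequence, e.g. label or text embeddings

tokens = torch.randn(batch, num_tokens, channels)
context = torch.randn(batch, context_len, cross_attention_dim)

# queries come from the image tokens, keys and values from the conditioning context
to_q = nn.Linear(channels, channels, bias=False)
to_k = nn.Linear(cross_attention_dim, channels, bias=False)
to_v = nn.Linear(cross_attention_dim, channels, bias=False)

scores = to_q(tokens) @ to_k(context).transpose(-1, -2) / channels**0.5
out = scores.softmax(dim=-1) @ to_v(context)   # (B, num_tokens, channels), same shape as the image tokens
print(out.shape)                               # torch.Size([1, 64, 32])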
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): """ - DiffusionModelUNetMaisi extends the DiffusionModelUNet class to support additional - input features like region indices and spacing. This class is specifically designed - for enhanced image synthesis in medical applications. + Unet's up block containing resnet and upsamplers blocks. Args: - spatial_dims: Number of spatial dimensions (2 or 3). - in_channels: Number of input channels. - out_channels: Number of output channels. - num_res_blocks: Number of residual blocks per level. - num_channels: Tuple of block output channels. - attention_levels: List indicating which levels have attention. - norm_num_groups: Number of groups for group normalization. - norm_eps: Epsilon for group normalization. - resblock_updown: If True, use residual blocks for up/downsampling. - num_head_channels: Number of channels in each attention head. - with_conditioning: If True, add spatial transformers for conditioning. - transformer_num_layers: Number of layers of Transformer blocks to use. - cross_attention_dim: Number of context dimensions for cross-attention. - num_class_embeds: Number of class embeddings for class-conditional generation. - upcast_attention: If True, upcast attention operations to full precision. - use_flash_attention: If True, use flash attention for memory efficient attention. - dropout_cattn: Dropout value for cross-attention layers. - input_top_region_index: If True, include top region index in the input. - input_bottom_region_index: If True, include bottom region index in the input. - input_spacing: If True, include spacing information in the input. + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. 
+ norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. """ def __init__( self, spatial_dims: int, in_channels: int, + prev_output_channel: int, out_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), + temb_channels: int, + num_res_blocks: int = 1, norm_num_groups: int = 32, norm_eps: float = 1e-6, + add_upsample: bool = True, resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, - input_top_region_index: bool = False, - input_bottom_region_index: bool = False, - input_spacing: bool = False, ) -> None: - super().__init__( + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, + temb_channels=temb_channels, 
num_res_blocks=num_res_blocks, - num_channels=num_channels, - attention_levels=attention_levels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, resblock_updown=resblock_updown, 
num_head_channels=num_head_channels, - with_conditioning=with_conditioning, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, - num_class_embeds=num_class_embeds, upcast_attention=upcast_attention, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNetMaisi(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + input_top_region_index: bool = False, + input_bottom_region_index: bool = False, + input_spacing: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
+ ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("CustomDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) self.input_top_region_index = input_top_region_index self.input_bottom_region_index = input_bottom_region_index self.input_spacing = input_spacing - time_embed_dim = num_channels[0] * 4 new_time_embed_dim = time_embed_dim if self.input_top_region_index: + # self.top_region_index_layer = nn.Linear(4, time_embed_dim) self.top_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_bottom_region_index: + # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) self.bottom_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_spacing: + # self.spacing_layer = nn.Linear(3, time_embed_dim) self.spacing_layer = nn.Sequential( nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim - self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, new_time_embed_dim) + # down + self.down_blocks = nn.ModuleList([]) + 
output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=new_time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), ) def forward( @@ -152,51 +1909,86 @@ def forward( spacing_tensor: torch.Tensor | None = None, ) -> torch.Tensor: """ - Forward pass through the DiffusionModelUNetMaisi. - Args: - x: Input tensor of shape (N, C, SpatialDims). - timesteps: Timestep tensor of shape (N,). - context: Optional context tensor of shape (N, 1, ContextDim). 
- class_labels: Optional class label tensor of shape (N,). - down_block_additional_residuals: Optional additional residual tensors for down blocks. - mid_block_additional_residual: Optional additional residual tensor for mid block. - top_region_index_tensor: Optional tensor for top region index of shape (N, 4). - bottom_region_index_tensor: Optional tensor for bottom region index of shape (N, 4). - spacing_tensor: Optional tensor for spacing information of shape (N, 3). - - Returns: - Output tensor of shape (N, C, SpatialDims). + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). """ - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]).to(dtype=x.dtype) + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) + # print(f't_emb: {t_emb}; timesteps {timesteps}.') + # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') + # print(f"emb: {torch.norm(emb)}") - if self.num_class_embeds is not None and class_labels is not None: - class_emb = self.class_embedding(class_labels).to(dtype=x.dtype) + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb - if self.input_top_region_index and top_region_index_tensor is not None: - emb = torch.cat((emb, self.top_region_index_layer(top_region_index_tensor)), dim=1) - if self.input_bottom_region_index and bottom_region_index_tensor is not None: - emb = torch.cat((emb, self.bottom_region_index_layer(bottom_region_index_tensor)), dim=1) - if self.input_spacing and spacing_tensor is not None: - emb = torch.cat((emb, self.spacing_layer(spacing_tensor)), dim=1) + # 3. input + if self.input_top_region_index: + _emb = self.top_region_index_layer(top_region_index_tensor) + # print(f"top_region_index_layer: {torch.norm(_emb)} {_emb.size()}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_bottom_region_index: + _emb = self.bottom_region_index_layer(bottom_region_index_tensor) + # print(f"bottom_region_index_layer: {torch.norm(_emb)}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_spacing: + _emb = self.spacing_layer(spacing_tensor) + # print(f"spacing_layer: {torch.norm(_emb)}") + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') + # 3. initial convolution h = self.conv_in(x) - down_block_res_samples = [h] + # print(f"x: {torch.norm(x)}; h: {torch.norm(h)}") + + # 4. 
down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] for downsample_block in self.down_blocks: h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - down_block_res_samples.extend(res_samples) + for residual in res_samples: + down_block_res_samples.append(residual) + # Additional residual conections for Controlnets if down_block_additional_residuals is not None: - down_block_res_samples = [ - res + add_res for res, add_res in zip(down_block_res_samples, down_block_additional_residuals) - ] + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + + # 5. mid h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets if mid_block_additional_residual is not None: h = h + mid_block_additional_residual + # 6. up for upsample_block in self.up_blocks: res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] From f82bf6b06c2e0f7a235ab3e4cc65cb07e1946bf1 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Fri, 21 Jun 2024 23:19:53 -0600 Subject: [PATCH 19/36] update unet Signed-off-by: Dong Yang --- .../networks/diffusion_model_unet_maisi.py | 36 ++----------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index e8aa1fe695..57f556ecdc 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -41,26 +41,12 @@ from monai.networks.blocks import Convolution, MLPBlock from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, optional_import -# To install xformers, use pip install xformers==0.0.16rc401 -if importlib.util.find_spec("xformers") is not None: - import xformers - import xformers.ops - - has_xformers = True -else: - xformers = None - has_xformers = False - - -# TODO: Use MONAI's optional_import -# from monai.utils import optional_import -# xformers, has_xformers = optional_import("xformers.ops", name="xformers") +xformers, has_xformers = optional_import("xformers.ops", name="xformers") __all__ = ["CustomDiffusionModelUNet"] - def zero_module(module: nn.Module) -> nn.Module: """ Zero out the parameters of a module and return it. @@ -469,8 +455,6 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_peri embedding_dim: the dimension of the output. max_period: controls the minimum frequency of the embeddings. 
""" - # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') - if timesteps.ndim != 1: raise ValueError("Timesteps should be a 1d-array") @@ -1778,19 +1762,16 @@ def __init__( new_time_embed_dim = time_embed_dim if self.input_top_region_index: - # self.top_region_index_layer = nn.Linear(4, time_embed_dim) self.top_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_bottom_region_index: - # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) self.bottom_region_index_layer = nn.Sequential( nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) new_time_embed_dim += time_embed_dim if self.input_spacing: - # self.spacing_layer = nn.Linear(3, time_embed_dim) self.spacing_layer = nn.Sequential( nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) @@ -1925,9 +1906,6 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) - # print(f't_emb: {t_emb}; timesteps {timesteps}.') - # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') - # print(f"emb: {torch.norm(emb)}") # 2. class if self.num_class_embeds is not None: @@ -1940,26 +1918,16 @@ def forward( # 3. input if self.input_top_region_index: _emb = self.top_region_index_layer(top_region_index_tensor) - # print(f"top_region_index_layer: {torch.norm(_emb)} {_emb.size()}") - # emb = emb + _emb.to(dtype=x.dtype) emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') if self.input_bottom_region_index: _emb = self.bottom_region_index_layer(bottom_region_index_tensor) - # print(f"bottom_region_index_layer: {torch.norm(_emb)}") - # emb = emb + _emb.to(dtype=x.dtype) emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') if self.input_spacing: _emb = self.spacing_layer(spacing_tensor) - # print(f"spacing_layer: {torch.norm(_emb)}") - # emb = emb + _emb.to(dtype=x.dtype) emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') # 3. initial convolution h = self.conv_in(x) - # print(f"x: {torch.norm(x)}; h: {torch.norm(h)}") # 4. 
down if context is not None and self.with_conditioning is False: From 31d67955f1821e306ada0abcc962c4586d5a7ae8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jun 2024 05:20:21 +0000 Subject: [PATCH 20/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../apps/generation/maisi/networks/diffusion_model_unet_maisi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 57f556ecdc..f5915be6c7 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -31,7 +31,6 @@ from __future__ import annotations -import importlib.util import math from collections.abc import Sequence From d423cd3e1d4d204ec5735aa0d8a23a11f4f02634 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 05:39:00 +0000 Subject: [PATCH 21/36] update diffusion unet Signed-off-by: dongyang0122 --- .../networks/diffusion_model_unet_maisi.py | 1606 +---------------- 1 file changed, 18 insertions(+), 1588 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index f5915be6c7..b4e7472374 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -31,1602 +31,32 @@ from __future__ import annotations -import math +import importlib.util from collections.abc import Sequence import torch -import torch.nn.functional as F +from generative.networks.nets.diffusion_model_unet import ( + get_down_block, + get_mid_block, + get_timestep_embedding, + get_up_block, + zero_module, +) from torch import nn -from monai.networks.blocks import Convolution, MLPBlock -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep, optional_import +from monai.networks.blocks import Convolution +from monai.utils import ensure_tuple_rep -xformers, has_xformers = optional_import("xformers.ops", name="xformers") +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops -__all__ = ["CustomDiffusionModelUNet"] + has_xformers = True +else: + xformers = None + has_xformers = False -def zero_module(module: nn.Module) -> nn.Module: - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -class CrossAttention(nn.Module): - """ - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - return self.to_out(x) - - -class BasicTransformerBlock(nn.Module): - """ - A basic Transformer block. - - Args: - num_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - dropout: dropout probability to use. 
- cross_attention_dim: size of the context vector for cross attention. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - num_channels: int, - num_attention_heads: int, - num_head_channels: int, - dropout: float = 0.0, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attn1 = CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention - self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) - self.attn2 = CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None - self.norm1 = nn.LayerNorm(num_channels) - self.norm2 = nn.LayerNorm(num_channels) - self.norm3 = nn.LayerNorm(num_channels) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # 1. Self-Attention - x = self.attn1(self.norm1(x)) + x - - # 2. Cross-Attention - x = self.attn2(self.norm2(x), context=context) + x - - # 3. Feed-forward - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - num_layers: number of layers of Transformer blocks to use. - dropout: dropout probability to use. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_attention_heads: int, - num_head_channels: int, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - inner_dim = num_attention_heads * num_head_channels - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - - self.proj_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=inner_dim, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - num_channels=inner_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - for _ in range(num_layers) - ] - ) - - self.proj_out = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=inner_dim, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - ) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # note: if no context is given, cross-attention defaults to self-attention - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - residual = x - x = self.norm(x) - x = self.proj_in(x) - - inner_dim = x.shape[1] - - if self.spatial_dims == 2: - x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - if self.spatial_dims == 3: - x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) - - for block in self.transformer_blocks: - x = block(x, context=context) - - if self.spatial_dims == 2: - x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - if self.spatial_dims == 3: - x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() - - x = self.proj_out(x) - return x + residual - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - -def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: - """ - Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic - Models" https://arxiv.org/abs/2006.11239. - - Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - embedding_dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - """ - if timesteps.ndim != 1: - raise ValueError("Timesteps should be a 1d-array") - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - freqs = torch.exp(exponent / half_dim) - - args = timesteps[:, None].float() * freqs[None, :] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) - - return embedding - - -class Downsample(nn.Module): - """ - Downsampling layer. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is - False, the number of output channels must be the same as the number of input channels. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points - for each dimension. - """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.op = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=2, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - if self.num_channels != self.out_channels: - raise ValueError("num_channels and out_channels must be equal when use_conv=False") - self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError( - f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " - f"({self.num_channels})" - ) - return self.op(x) - - -class Upsample(nn.Module): - """ - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. 
- """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - self.conv = None - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - if self.use_conv: - x = self.conv(x) - return x - - -class ResnetBlock(nn.Module): - """ - Residual block with timestep conditioning. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - out_channels: number of output channels. - up: if True, performs upsampling. - down: if True, performs downsampling. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - out_channels: int | None = None, - up: bool = False, - down: bool = False, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = in_channels - self.emb_channels = temb_channels - self.out_channels = out_channels or in_channels - self.up = up - self.down = down - - self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.nonlinearity = nn.SiLU() - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) - elif down: - self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) - - self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) - - self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) - self.conv2 = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - if self.out_channels == in_channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = self.nonlinearity(h) - - if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() - x = self.upsample(x) - h = 
self.upsample(h) - elif self.downsample is not None: - x = self.downsample(x) - h = self.downsample(h) - - h = self.conv1(h) - - if self.spatial_dims == 2: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] - else: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] - h = h + temb - - h = self.norm2(h) - h = self.nonlinearity(h) - h = self.conv2(h) - - return self.skip_connection(x) + h - - -class DownBlock(nn.Module): - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. 
- use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class CrossAttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and cross-attention blocks. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = AttentionBlock( - spatial_dims=spatial_dims, - num_channels=in_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - del context - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class CrossAttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=in_channels, - num_attention_heads=in_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states, context=context) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class UpBlock(nn.Module): - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class AttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -def get_down_block( - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_downsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - 
resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return DownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - ) - - -def get_mid_block( - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int, - norm_eps: float, - with_conditioning: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_conditioning: - return CrossAttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return AttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - - -def get_up_block( - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_upsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - 
use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return UpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - ) +__all__ = ["DiffusionModelUNetMaisi"] class DiffusionModelUNetMaisi(nn.Module): From 54fd06819729fd564748eae882379b35a839ce1b Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 05:44:10 +0000 Subject: [PATCH 22/36] update diffusion unet Signed-off-by: dongyang0122 --- .../networks/diffusion_model_unet_maisi.py | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index b4e7472374..520d8596d6 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -61,29 +61,31 @@ class DiffusionModelUNetMaisi(nn.Module): """ - Unet network with timestep embedding and attention mechanisms for conditioning based on + U-Net network with timestep embedding and attention mechanisms for conditioning based on Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + spatial_dims: Number of spatial dimensions. + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks (see ResnetBlock) per level. Can be a single integer or a sequence of integers. + num_channels: Tuple of block output channels. + attention_levels: List of levels to add attention. + norm_num_groups: Number of groups for the normalization. + norm_eps: Epsilon for the normalization. + resblock_updown: If True, use residual blocks for up/downsampling. + num_head_channels: Number of channels in each attention head. Can be a single integer or a sequence of integers. + with_conditioning: If True, add spatial transformers to perform conditioning. 
+ transformer_num_layers: Number of layers of Transformer blocks to use. + cross_attention_dim: Number of context dimensions to use. + num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: If True, upcast attention operations to full precision. + use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers. + input_top_region_index: If True, use top region index input. + input_bottom_region_index: If True, use bottom region index input. + input_spacing: If True, use spacing input. """ def __init__( @@ -319,13 +321,21 @@ def forward( spacing_tensor: torch.Tensor | None = None, ) -> torch.Tensor: """ + Forward pass through the UNet model. + Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). - mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + x: Input tensor of shape (N, C, SpatialDims). + timesteps: Timestep tensor of shape (N,). + context: Context tensor of shape (N, 1, ContextDim). + class_labels: Class labels tensor of shape (N,). + down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims). + mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims). + top_region_index_tensor: Tensor representing top region index of shape (N, 4). + bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4). + spacing_tensor: Tensor representing spacing of shape (N, 3). + + Returns: + A tensor representing the output of the UNet model. """ # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) From 03d7de21c63072f0adeebc782ea0daff26d6b692 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 06:17:49 +0000 Subject: [PATCH 23/36] update Signed-off-by: dongyang0122 --- .../generation/maisi/inferers/__init__.py | 4 -- .../maisi/inferers/inferer_maisi.py | 16 +----- tests/test_diffusion_inferer_maisi.py | 48 +++++++++--------- tests/test_latent_diffusion_inferer_maisi.py | 50 +++++++++---------- 4 files changed, 51 insertions(+), 67 deletions(-) diff --git a/monai/apps/generation/maisi/inferers/__init__.py b/monai/apps/generation/maisi/inferers/__init__.py index 5bdca05bd0..1e97f89407 100644 --- a/monai/apps/generation/maisi/inferers/__init__.py +++ b/monai/apps/generation/maisi/inferers/__init__.py @@ -8,7 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from __future__ import annotations - -from .inferer_maisi import DiffusionInfererMaisi, LatentDiffusionInfererMaisi diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py index dce5491319..bbede2a416 100644 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ b/monai/apps/generation/maisi/inferers/inferer_maisi.py @@ -12,12 +12,10 @@ from __future__ import annotations from collections.abc import Callable -from functools import partial import torch import torch.nn as nn from generative.inferers.inferer import DiffusionInferer -from generative.networks.nets import VQVAE, SPADEDiffusionModelUNet from monai.data import decollate_batch from monai.utils import optional_import @@ -158,7 +156,7 @@ def sample( class LatentDiffusionInfererMaisi(DiffusionInfererMaisi): """ - LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + LatentDiffusionInferer takes a stage 1 model (AutoencoderKL), diffusion model, and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from the model. Args: @@ -286,8 +284,6 @@ def get_likelihood( verbose: bool = True, resample_latent_likelihoods: bool = False, resample_interpolation_mode: str = "nearest", - seg: torch.Tensor | None = None, - quantized: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -306,11 +302,7 @@ def get_likelihood( resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial dimension as the input images. resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', - or 'trilinear; - seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model - is instance of SPADEAutoencoderKL, segmentation must be provided. - quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM - are quantized or not. + or 'trilinear. 
""" if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( @@ -318,16 +310,12 @@ def get_likelihood( ) autoencode = autoencoder_model.encode_stage_2_inputs - if isinstance(autoencoder_model, VQVAE): - autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) latents = autoencode(inputs) * self.scale_factor if self.ldm_latent_shape is not None: latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) get_likelihood = super().get_likelihood - if isinstance(diffusion_model, SPADEDiffusionModelUNet): - get_likelihood = partial(super().get_likelihood, seg=seg) outputs = get_likelihood( inputs=latents, diff --git a/tests/test_diffusion_inferer_maisi.py b/tests/test_diffusion_inferer_maisi.py index ee5d1ec728..52e51294fb 100644 --- a/tests/test_diffusion_inferer_maisi.py +++ b/tests/test_diffusion_inferer_maisi.py @@ -15,11 +15,11 @@ from unittest import skipUnless import torch +from generative.networks.schedulers import DDIMScheduler, DDPMScheduler from parameterized import parameterized from monai.apps.generation.maisi.inferers.inferer_maisi import DiffusionInfererMaisi from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi -from monai.networks.schedulers import DDIMScheduler, DDPMScheduler from monai.utils import optional_import _, has_scipy = optional_import("scipy") @@ -31,7 +31,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "channels": [8], + "num_channels": [8], "norm_num_groups": 8, "attention_levels": [True], "num_res_blocks": 1, @@ -44,7 +44,7 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "channels": [8], + "num_channels": [8], "norm_num_groups": 8, "attention_levels": [True], "num_res_blocks": 1, @@ -69,9 +69,9 @@ def test_call(self, model_params, input_shape): inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample = inferer( inputs=input, noise=noise, @@ -94,9 +94,9 @@ def test_sample_intermediates(self, model_params, input_shape): scheduler = DDPMScheduler(num_train_timesteps=10) inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model, @@ -120,9 +120,9 @@ def test_ddpm_sampler(self, model_params, input_shape): scheduler = DDPMScheduler(num_train_timesteps=1000) inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = 
torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model, @@ -146,9 +146,9 @@ def test_ddim_sampler(self, model_params, input_shape): scheduler = DDIMScheduler(num_train_timesteps=1000) inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model, @@ -175,9 +175,9 @@ def test_sampler_conditioned(self, model_params, input_shape): inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) conditioning = torch.randn([input_shape[0], 1, 3]).to(device) - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model, @@ -210,9 +210,9 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): scheduler = DDIMScheduler(num_train_timesteps=1000) inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model, @@ -248,9 +248,9 @@ def test_call_conditioned_concat(self, model_params, input_shape): inferer = DiffusionInfererMaisi(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - top_region_index_tensor = torch.randn(input_shape).to(device) - bottom_region_index_tensor = torch.randn(input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) sample = inferer( inputs=input, noise=noise, diff --git a/tests/test_latent_diffusion_inferer_maisi.py b/tests/test_latent_diffusion_inferer_maisi.py index 3495da3cda..926564ea0f 100644 --- a/tests/test_latent_diffusion_inferer_maisi.py +++ b/tests/test_latent_diffusion_inferer_maisi.py @@ -15,12 +15,12 @@ from unittest import skipUnless import torch -from generative.networks.nets.autoencoderkl import AutoencoderKL +from generative.networks.nets import AutoencoderKL +from generative.networks.schedulers import DDPMScheduler from parameterized import parameterized from 
monai.apps.generation.maisi.inferers.inferer_maisi import LatentDiffusionInfererMaisi from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi -from monai.networks.schedulers import DDPMScheduler from monai.utils import optional_import _, has_einops = optional_import("einops") @@ -31,7 +31,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "channels": (4, 4), + "num_channels": (4, 4), "latent_channels": 3, "attention_levels": [False, False], "num_res_blocks": 1, @@ -44,7 +44,7 @@ "spatial_dims": 2, "in_channels": 3, "out_channels": 3, - "channels": [4, 4], + "num_channels": [4, 4], "norm_num_groups": 4, "attention_levels": [False, False], "num_res_blocks": 1, @@ -79,9 +79,9 @@ def test_prediction_shape( input = torch.randn(input_shape).to(device) noise = torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) @@ -120,9 +120,9 @@ def test_sample_shape( stage_2.eval() noise = torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) @@ -159,9 +159,9 @@ def test_sample_intermediates( stage_2.eval() noise = torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) @@ -201,9 +201,9 @@ def test_get_likelihoods( stage_2.eval() input = torch.randn(input_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) @@ -242,9 +242,9 @@ def test_resample_likelihoods( stage_2.eval() input = torch.randn(input_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - 
bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) @@ -291,9 +291,9 @@ def test_prediction_shape_conditioned_concat( conditioning_shape = list(latent_shape) conditioning_shape[1] = n_concat_channel conditioning = torch.randn(conditioning_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) @@ -341,9 +341,9 @@ def test_sample_shape_conditioned_concat( conditioning_shape = list(latent_shape) conditioning_shape[1] = n_concat_channel conditioning = torch.randn(conditioning_shape).to(device) - top_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - bottom_region_index_tensor = torch.randint(0, 2, input_shape).to(device) - spacing_tensor = torch.randn(input_shape).to(device) + top_region_index_tensor = torch.rand((1, 4)).to(device) + bottom_region_index_tensor = torch.rand((1, 4)).to(device) + spacing_tensor = torch.rand((1, 3)).to(device) scheduler = DDPMScheduler(num_train_timesteps=10) inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) From 5324e0e8d5414f142c364569c731441cfd870d30 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 06:52:14 +0000 Subject: [PATCH 24/36] update Signed-off-by: dongyang0122 --- .../generation/maisi/inferers/__init__.py | 10 - .../maisi/inferers/inferer_maisi.py | 335 ---------------- tests/test_diffusion_inferer_maisi.py | 269 ------------- tests/test_latent_diffusion_inferer_maisi.py | 367 ------------------ 4 files changed, 981 deletions(-) delete mode 100644 monai/apps/generation/maisi/inferers/__init__.py delete mode 100644 monai/apps/generation/maisi/inferers/inferer_maisi.py delete mode 100644 tests/test_diffusion_inferer_maisi.py delete mode 100644 tests/test_latent_diffusion_inferer_maisi.py diff --git a/monai/apps/generation/maisi/inferers/__init__.py b/monai/apps/generation/maisi/inferers/__init__.py deleted file mode 100644 index 1e97f89407..0000000000 --- a/monai/apps/generation/maisi/inferers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
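Before the larger deletions that follow, a minimal, hypothetical sketch of a training-style forward pass through the retained DiffusionModelUNetMaisi may help orient the reader. The constructor arguments are borrowed from the removed 2D test configuration above, and the extra conditioning tensors follow the shapes documented in the updated forward() docstring (region indices of shape (N, 4), spacing of shape (N, 3)); the batch size and spatial size are illustrative assumptions, and this is not code from the patch itself.

import torch
import torch.nn.functional as F
from generative.networks.schedulers import DDPMScheduler

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

# Constructor arguments mirror the removed 2D test configuration (assumption).
model = DiffusionModelUNetMaisi(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=[8],
    norm_num_groups=8,
    attention_levels=[True],
    num_res_blocks=1,
    num_head_channels=8,
)
scheduler = DDPMScheduler(num_train_timesteps=1000)

images = torch.randn(1, 1, 16, 16)  # (N, C, H, W); size chosen only for illustration
noise = torch.randn_like(images)
timesteps = torch.randint(0, 1000, (1,)).long()
noisy = scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

# MAISI-specific conditioning; whether these inputs are consumed depends on the
# input_top_region_index / input_bottom_region_index / input_spacing flags.
top_region_index_tensor = torch.rand(1, 4)
bottom_region_index_tensor = torch.rand(1, 4)
spacing_tensor = torch.rand(1, 3)

prediction = model(
    x=noisy,
    timesteps=timesteps,
    top_region_index_tensor=top_region_index_tensor,
    bottom_region_index_tensor=bottom_region_index_tensor,
    spacing_tensor=spacing_tensor,
)
loss = F.mse_loss(prediction.float(), noise.float())  # assuming a noise-prediction objective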
diff --git a/monai/apps/generation/maisi/inferers/inferer_maisi.py b/monai/apps/generation/maisi/inferers/inferer_maisi.py deleted file mode 100644 index bbede2a416..0000000000 --- a/monai/apps/generation/maisi/inferers/inferer_maisi.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Callable - -import torch -import torch.nn as nn -from generative.inferers.inferer import DiffusionInferer - -from monai.data import decollate_batch -from monai.utils import optional_import - -tqdm, has_tqdm = optional_import("tqdm", name="tqdm") - - -class DiffusionInfererMaisi(DiffusionInferer): - """ - DiffusionInfererMaisi takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass - for a training iteration, and sample from the model. - - Args: - scheduler: diffusion scheduler. - """ - - def __init__(self, scheduler: nn.Module) -> None: - super().__init__(scheduler=scheduler) - - def __call__( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, - mode: str = "crossattn", - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Implements the forward pass for a supervised training iteration. - - Args: - inputs: Input image to which noise is added. - diffusion_model: diffusion model. - noise: random noise, of the same shape as the input. - timesteps: random timesteps. - condition: Conditioning for network input. - mode: Conditioning mode for the network. - top_region_index_tensor: tensor for top region index. - bottom_region_index_tensor: tensor for bottom region index. - spacing_tensor: tensor for spacing. 
- """ - if mode not in ["crossattn", "concat"]: - raise NotImplementedError(f"{mode} condition is not supported") - - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - if mode == "concat": - noisy_image = torch.cat([noisy_image, condition], dim=1) - condition = None - prediction = diffusion_model( - x=noisy_image, - timesteps=timesteps, - context=condition, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - - return prediction - - @torch.no_grad() - def sample( - self, - input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - verbose: bool = True, - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - """ - Args: - input_noise: random noise, of the same shape as the desired sample. - diffusion_model: model to sample from. - scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. - save_intermediates: whether to return intermediates along the sampling change. - intermediate_steps: if save_intermediates is True, saves every n steps. - conditioning: Conditioning for network input. - mode: Conditioning mode for the network. - verbose: if true, prints the progression bar of the sampling process. - top_region_index_tensor: tensor for top region index. - bottom_region_index_tensor: tensor for bottom region index. - spacing_tensor: tensor for spacing. - """ - if mode not in ["crossattn", "concat"]: - raise NotImplementedError(f"{mode} condition is not supported") - - if not scheduler: - scheduler = self.scheduler - image = input_noise - if verbose and has_tqdm: - progress_bar = tqdm(scheduler.timesteps) - else: - progress_bar = iter(scheduler.timesteps) - intermediates = [] - - for t in progress_bar: - - # 1. predict noise model_output - if mode == "concat": - model_input = torch.cat([image, conditioning], dim=1) - model_output = diffusion_model( - model_input, - timesteps=torch.Tensor((t,)).to(input_noise.device), - context=None, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - else: - model_output = diffusion_model( - image, - timesteps=torch.Tensor((t,)).to(input_noise.device), - context=conditioning, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - - # 2. compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) - if save_intermediates and t % intermediate_steps == 0: - intermediates.append(image) - - if save_intermediates: - return image, intermediates - else: - return image - - -class LatentDiffusionInfererMaisi(DiffusionInfererMaisi): - """ - LatentDiffusionInferer takes a stage 1 model (AutoencoderKL), diffusion model, and a scheduler, and can - be used to perform a signal forward pass for a training iteration, and sample from the model. - - Args: - scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. 
- scale_factor: scale factor to multiply the values of the latent representation before processing it by the - second stage. - """ - - def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None: - super().__init__(scheduler=scheduler) - self.scale_factor = scale_factor - - def __call__( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, - mode: str = "crossattn", - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Implements the forward pass for a supervised training iteration. - - Args: - inputs: input image to which the latent representation will be extracted and noise is added. - autoencoder_model: first stage model. - diffusion_model: diffusion model. - noise: random noise, of the same shape as the latent representation. - timesteps: random timesteps. - condition: conditioning for network input. - mode: Conditioning mode for the network. - """ - with torch.no_grad(): - latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor - - prediction = super().__call__( - inputs=latent, - diffusion_model=diffusion_model, - noise=noise, - timesteps=timesteps, - condition=condition, - mode=mode, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - - return prediction - - @torch.no_grad() - def sample( - self, - input_noise: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - verbose: bool = True, - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - """ - Args: - input_noise: random noise, of the same shape as the desired latent representation. - autoencoder_model: first stage model. - diffusion_model: model to sample from. - scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. - save_intermediates: whether to return intermediates along the sampling change - intermediate_steps: if save_intermediates is True, saves every n steps - conditioning: Conditioning for network input. - mode: Conditioning mode for the network. - verbose: if true, prints the progression bar of the sampling process. 
- """ - outputs = super().sample( - input_noise=input_noise, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - intermediate_steps=intermediate_steps, - conditioning=conditioning, - mode=mode, - verbose=verbose, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - - if save_intermediates: - latent, latent_intermediates = outputs - else: - latent = outputs - - image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) - - if save_intermediates: - intermediates = [] - for latent_intermediate in latent_intermediates: - intermediates.append(autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)) - return image, intermediates - - else: - return image - - @torch.no_grad() - def get_likelihood( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", - ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - """ - Computes the log-likelihoods of the latent representations of the input. - - Args: - inputs: input images, NxCxHxW[xD] - autoencoder_model: first stage model. - diffusion_model: model to compute likelihood from - scheduler: diffusion scheduler. If none provided will use the class attribute scheduler - save_intermediates: save the intermediate spatial KL maps - conditioning: Conditioning for network input. - mode: Conditioning mode for the network. - original_input_range: the [min,max] intensity range of the input data before any scaling was applied. - scaled_input_range: the [min,max] intensity range of the input data after scaling. - verbose: if true, prints the progression bar of the sampling process. - resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial - dimension as the input images. - resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', - or 'trilinear. 
- """ - if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): - raise ValueError( - f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" - ) - - autoencode = autoencoder_model.encode_stage_2_inputs - latents = autoencode(inputs) * self.scale_factor - - if self.ldm_latent_shape is not None: - latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) - - get_likelihood = super().get_likelihood - - outputs = get_likelihood( - inputs=latents, - diffusion_model=diffusion_model, - scheduler=scheduler, - save_intermediates=save_intermediates, - conditioning=conditioning, - mode=mode, - verbose=verbose, - ) - - if save_intermediates and resample_latent_likelihoods: - intermediates = outputs[1] - resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) - intermediates = [resizer(x) for x in intermediates] - outputs = (outputs[0], intermediates) - return outputs diff --git a/tests/test_diffusion_inferer_maisi.py b/tests/test_diffusion_inferer_maisi.py deleted file mode 100644 index 52e51294fb..0000000000 --- a/tests/test_diffusion_inferer_maisi.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest -from unittest import skipUnless - -import torch -from generative.networks.schedulers import DDIMScheduler, DDPMScheduler -from parameterized import parameterized - -from monai.apps.generation.maisi.inferers.inferer_maisi import DiffusionInfererMaisi -from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi -from monai.utils import optional_import - -_, has_scipy = optional_import("scipy") -_, has_einops = optional_import("einops") - -TEST_CASES = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": [8], - "norm_num_groups": 8, - "attention_levels": [True], - "num_res_blocks": 1, - "num_head_channels": 8, - }, - (2, 1, 8, 8), - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_channels": [8], - "norm_num_groups": 8, - "attention_levels": [True], - "num_res_blocks": 1, - "num_head_channels": 8, - }, - (2, 1, 8, 8, 8), - ], -] - - -class TestDiffusionInfererMaisi(unittest.TestCase): - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_call(self, model_params, input_shape): - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - input = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample = inferer( - inputs=input, - noise=noise, - diffusion_model=model, - timesteps=timesteps, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(sample.shape, input_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sample_intermediates(self, model_params, input_shape): - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - noise = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_ddpm_sampler(self, model_params, input_shape): - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - noise = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=1000) - 
inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_ddim_sampler(self, model_params, input_shape): - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - noise = torch.randn(input_shape).to(device) - scheduler = DDIMScheduler(num_train_timesteps=1000) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sampler_conditioned(self, model_params, input_shape): - model_params["with_conditioning"] = True - model_params["cross_attention_dim"] = 3 - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - noise = torch.randn(input_shape).to(device) - scheduler = DDIMScheduler(num_train_timesteps=1000) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - conditioning = torch.randn([input_shape[0], 1, 3]).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - conditioning=conditioning, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sampler_conditioned_concat(self, model_params, input_shape): - model_params = model_params.copy() - n_concat_channel = 2 - model_params["in_channels"] = model_params["in_channels"] + n_concat_channel - model_params["cross_attention_dim"] = None - model_params["with_conditioning"] = False - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - noise = torch.randn(input_shape).to(device) - conditioning_shape = list(input_shape) - conditioning_shape[1] = n_concat_channel - conditioning = torch.randn(conditioning_shape).to(device) - scheduler = 
DDIMScheduler(num_train_timesteps=1000) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - conditioning=conditioning, - mode="concat", - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_call_conditioned_concat(self, model_params, input_shape): - model_params = model_params.copy() - n_concat_channel = 2 - model_params["in_channels"] = model_params["in_channels"] + n_concat_channel - model_params["cross_attention_dim"] = None - model_params["with_conditioning"] = False - model = DiffusionModelUNetMaisi(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - input = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - conditioning_shape = list(input_shape) - conditioning_shape[1] = n_concat_channel - conditioning = torch.randn(conditioning_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = DiffusionInfererMaisi(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - sample = inferer( - inputs=input, - noise=noise, - diffusion_model=model, - timesteps=timesteps, - condition=conditioning, - mode="concat", - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(sample.shape, input_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_latent_diffusion_inferer_maisi.py b/tests/test_latent_diffusion_inferer_maisi.py deleted file mode 100644 index 926564ea0f..0000000000 --- a/tests/test_latent_diffusion_inferer_maisi.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest -from unittest import skipUnless - -import torch -from generative.networks.nets import AutoencoderKL -from generative.networks.schedulers import DDPMScheduler -from parameterized import parameterized - -from monai.apps.generation.maisi.inferers.inferer_maisi import LatentDiffusionInfererMaisi -from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi -from monai.utils import optional_import - -_, has_einops = optional_import("einops") -TEST_CASES = [ - [ - "AutoencoderKL", - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": (4, 4), - "latent_channels": 3, - "attention_levels": [False, False], - "num_res_blocks": 1, - "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, - "norm_num_groups": 4, - }, - "DiffusionModelUNetMaisi", - { - "spatial_dims": 2, - "in_channels": 3, - "out_channels": 3, - "num_channels": [4, 4], - "norm_num_groups": 4, - "attention_levels": [False, False], - "num_res_blocks": 1, - "num_head_channels": 4, - }, - (1, 1, 8, 8), - (1, 3, 4, 4), - ] -] - - -class TestLatentDiffusionInfererMaisi(unittest.TestCase): - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_prediction_shape( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - input = torch.randn(input_shape).to(device) - noise = torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(prediction.shape, latent_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sample_shape( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - noise = 
torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - sample = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(sample.shape, input_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sample_intermediates( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - noise = torch.randn(latent_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - sample, intermediates = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - self.assertEqual(intermediates[0].shape, input_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_get_likelihoods( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - input = torch.randn(input_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - sample, intermediates = inferer.get_likelihood( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, 
- save_intermediates=True, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - self.assertEqual(intermediates[0].shape, latent_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_resample_likelihoods( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - input = torch.randn(input_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - sample, intermediates = inferer.get_likelihood( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - save_intermediates=True, - resample_latent_likelihoods=True, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(len(intermediates), 10) - self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_prediction_shape_conditioned_concat( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2_params = stage_2_params.copy() - n_concat_channel = 3 - stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - input = torch.randn(input_shape).to(device) - noise = torch.randn(latent_shape).to(device) - conditioning_shape = list(latent_shape) - conditioning_shape[1] = n_concat_channel - conditioning = torch.randn(conditioning_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - 
diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - condition=conditioning, - mode="concat", - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(prediction.shape, latent_shape) - - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_sample_shape_conditioned_concat( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - raise ValueError(f"Unsupported autoencoder model type: {ae_model_type}") - - stage_2_params = stage_2_params.copy() - n_concat_channel = 3 - stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - stage_2 = DiffusionModelUNetMaisi(**stage_2_params) if dm_model_type == "DiffusionModelUNetMaisi" else None - if stage_2 is None: - raise ValueError(f"Unsupported diffusion model type: {dm_model_type}") - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - stage_1.to(device) - stage_2.to(device) - stage_1.eval() - stage_2.eval() - - noise = torch.randn(latent_shape).to(device) - conditioning_shape = list(latent_shape) - conditioning_shape[1] = n_concat_channel - conditioning = torch.randn(conditioning_shape).to(device) - top_region_index_tensor = torch.rand((1, 4)).to(device) - bottom_region_index_tensor = torch.rand((1, 4)).to(device) - spacing_tensor = torch.rand((1, 3)).to(device) - - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInfererMaisi(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - sample = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - conditioning=conditioning, - mode="concat", - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - ) - self.assertEqual(sample.shape, input_shape) - - -if __name__ == "__main__": - unittest.main() From 90d39615f3fd3cd20507e0dc3df9bc257e16e5ca Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 22 Jun 2024 15:35:21 +0000 Subject: [PATCH 25/36] update unet Signed-off-by: dongyang0122 --- .../maisi/networks/diffusion_model_unet_maisi.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 520d8596d6..68d559065c 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -379,12 +379,12 @@ def forward( # Additional residual conections for Controlnets if down_block_additional_residuals is not None: - new_down_block_res_samples = () + new_down_block_res_samples: list[torch.Tensor] = [] for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) + new_down_block_res_samples.append(down_block_res_sample) down_block_res_samples = new_down_block_res_samples @@ -404,4 +404,6 @@ def forward( # 7. 
output block h = self.out(h) - return h + if isinstance(h, torch.Tensor): + return h + return torch.tensor(h) From eb2ca47a4a01359562f1ab5409d86019ac8c7188 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 24 Jun 2024 21:02:08 +0000 Subject: [PATCH 26/36] update Signed-off-by: dongyang0122 --- monai/apps/generation/maisi/__init__.py | 21 ------------------ .../networks/diffusion_model_unet_maisi.py | 22 ++++++++++++------- requirements-dev.txt | 1 + tests/test_diffusion_model_unet_maisi.py | 8 ++++++- 4 files changed, 22 insertions(+), 30 deletions(-) diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py index ef42d42730..1e97f89407 100644 --- a/monai/apps/generation/maisi/__init__.py +++ b/monai/apps/generation/maisi/__init__.py @@ -8,24 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -import subprocess -import sys - - -def install_and_import(package, package_fullname=None): - if package_fullname is None: - package_fullname = package - - try: - __import__(package) - except ImportError: - print(f"'{package}' is not installed. Installing now...") - subprocess.check_call([sys.executable, "-m", "pip", "install", package_fullname]) - print(f"'{package}' installation completed.") - __import__(package) - - -install_and_import("generative", "monai-generative") diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 68d559065c..ddb9b1a8b0 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -35,17 +35,22 @@ from collections.abc import Sequence import torch -from generative.networks.nets.diffusion_model_unet import ( - get_down_block, - get_mid_block, - get_timestep_embedding, - get_up_block, - zero_module, -) from torch import nn from monai.networks.blocks import Convolution -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, optional_import + +get_down_block, has_get_down_block = optional_import( + "generative.networks.nets.diffusion_model_unet", name="get_down_block" +) +get_mid_block, has_get_mid_block = optional_import( + "generative.networks.nets.diffusion_model_unet", name="get_mid_block" +) +get_timestep_embedding, has_get_timestep_embedding = optional_import( + "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +) +get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") +zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") if importlib.util.find_spec("xformers") is not None: import xformers @@ -56,6 +61,7 @@ xformers = None has_xformers = False + __all__ = ["DiffusionModelUNetMaisi"] diff --git a/requirements-dev.txt b/requirements-dev.txt index a8ba25966b..37e5917c6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,3 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub +monai-generative diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 069bb68879..3d6ace3b79 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -17,11 +17,15 @@ import torch from parameterized 
import parameterized -from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks import eval_mode from monai.utils import optional_import _, has_einops = optional_import("einops") +_, has_generative = optional_import("generative") + +if has_generative: + from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi + UNCOND_CASES_2D = [ [ @@ -288,6 +292,7 @@ ] +@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @skipUnless(has_einops, "Requires einops") @@ -505,6 +510,7 @@ def test_shape_with_additional_inputs(self, input_param): self.assertEqual(result.shape, (1, 1, 16, 16)) +@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) @skipUnless(has_einops, "Requires einops") From 867d0cd0ea5c906b5793eea33111b9ca9166ca1b Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 18:00:48 +0000 Subject: [PATCH 27/36] update xformers Signed-off-by: dongyang0122 --- .../maisi/networks/diffusion_model_unet_maisi.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index ddb9b1a8b0..225e5016a4 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -31,7 +31,6 @@ from __future__ import annotations -import importlib.util from collections.abc import Sequence import torch @@ -50,18 +49,9 @@ "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" ) get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") +xformers, has_xformers = optional_import("xformers") zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") -if importlib.util.find_spec("xformers") is not None: - import xformers - import xformers.ops - - has_xformers = True -else: - xformers = None - has_xformers = False - - __all__ = ["DiffusionModelUNetMaisi"] From 38d41c42b01e100aa8a583205beed31dd0e15507 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 18:12:15 +0000 Subject: [PATCH 28/36] update print messages Signed-off-by: dongyang0122 --- .../maisi/networks/diffusion_model_unet_maisi.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 225e5016a4..a763cc6616 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -110,22 +110,28 @@ def __init__( super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "DiffusionModelUNetMaisi expects dimension of the cross-attention conditioning (cross_attention_dim) " "when using with_conditioning." 
) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + "DiffusionModelUNetMaisi expects with_conditioning=True when specifying the cross_attention_dim." ) if dropout_cattn > 1.0 or dropout_cattn < 0.0: raise ValueError("Dropout cannot be negative or >1.0!") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): - raise ValueError("CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + raise ValueError( + f"DiffusionModelUNetMaisi expects all num_channels to be multiples of norm_num_groups, " + f"but got num_channels: {num_channels} and norm_num_groups: {norm_num_groups}" + ) if len(num_channels) != len(attention_levels): - raise ValueError("CustomDiffusionModelUNet expects num_channels being same size of attention_levels") + raise ValueError( + f"DiffusionModelUNetMaisi expects num_channels to have the same length as attention_levels, " + f"but got num_channels: {len(num_channels)} and attention_levels: {len(attention_levels)}" + ) if isinstance(num_head_channels, int): num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) From 8ad8ef8fa3f054d00b3b9228cf9f9ee76a85dcf8 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 18:22:48 +0000 Subject: [PATCH 29/36] update tensor formatting Signed-off-by: dongyang0122 --- .../maisi/networks/diffusion_model_unet_maisi.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index a763cc6616..1a0c7bc533 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -38,6 +38,7 @@ from monai.networks.blocks import Convolution from monai.utils import ensure_tuple_rep, optional_import +from monai.utils.type_conversion import convert_to_tensor get_down_block, has_get_down_block = optional_import( "generative.networks.nets.diffusion_model_unet", name="get_down_block" ) @@ -405,7 +406,5 @@ def forward( # 7.
output block h = self.out(h) - - if isinstance(h, torch.Tensor): - return h - return torch.tensor(h) + h_tensor: torch.Tensor = convert_to_tensor(h) + return h_tensor From 473ef4be313a65fbf40c9e405b8b6c8c49fb9994 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 18:43:49 +0000 Subject: [PATCH 30/36] refactor forward function Signed-off-by: dongyang0122 --- .../networks/diffusion_model_unet_maisi.py | 101 +++++++++--------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 1a0c7bc533..bf9e8f0ac4 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -311,36 +311,7 @@ def __init__( ), ) - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - down_block_additional_residuals: tuple[torch.Tensor] | None = None, - mid_block_additional_residual: torch.Tensor | None = None, - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Forward pass through the UNet model. - - Args: - x: Input tensor of shape (N, C, SpatialDims). - timesteps: Timestep tensor of shape (N,). - context: Context tensor of shape (N, 1, ContextDim). - class_labels: Class labels tensor of shape (N,). - down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims). - mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims). - top_region_index_tensor: Tensor representing top region index of shape (N, 4). - bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4). - spacing_tensor: Tensor representing spacing of shape (N, 3). - - Returns: - A tensor representing the output of the UNet model. - """ - # 1. time + def _get_time_and_class_embedding(self, x, timesteps, class_labels): t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) # timesteps does not contain any weights and will always return f32 tensors @@ -349,29 +320,27 @@ def forward( t_emb = t_emb.to(dtype=x.dtype) emb = self.time_embed(t_emb) - # 2. class if self.num_class_embeds is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb + return emb - # 3. input + def _get_input_embeddings(self, emb, top_index, bottom_index, spacing): if self.input_top_region_index: - _emb = self.top_region_index_layer(top_region_index_tensor) + _emb = self.top_region_index_layer(top_index) emb = torch.cat((emb, _emb), dim=1) if self.input_bottom_region_index: - _emb = self.bottom_region_index_layer(bottom_region_index_tensor) + _emb = self.bottom_region_index_layer(bottom_index) emb = torch.cat((emb, _emb), dim=1) if self.input_spacing: - _emb = self.spacing_layer(spacing_tensor) + _emb = self.spacing_layer(spacing) emb = torch.cat((emb, _emb), dim=1) + return emb - # 3. initial convolution - h = self.conv_in(x) - - # 4. 
down + def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals): if context is not None and self.with_conditioning is False: raise ValueError("model should have with_conditioning = True if context is provided") down_block_res_samples: list[torch.Tensor] = [h] @@ -390,21 +359,57 @@ def forward( new_down_block_res_samples.append(down_block_res_sample) down_block_res_samples = new_down_block_res_samples + return h, down_block_res_samples - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # Additional residual conections for Controlnets - if mid_block_additional_residual is not None: - h = h + mid_block_additional_residual - - # 6. up + def _apply_up_blocks(self, h, emb, context, down_block_res_samples): for upsample_block in self.up_blocks: res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - # 7. output block + return h + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Forward pass through the UNet model. + + Args: + x: Input tensor of shape (N, C, SpatialDims). + timesteps: Timestep tensor of shape (N,). + context: Context tensor of shape (N, 1, ContextDim). + class_labels: Class labels tensor of shape (N,). + down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims). + mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims). + top_region_index_tensor: Tensor representing top region index of shape (N, 4). + bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4). + spacing_tensor: Tensor representing spacing of shape (N, 3). + + Returns: + A tensor representing the output of the UNet model. 
+ """ + + emb = self._get_time_and_class_embedding(x, timesteps, class_labels) + emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) + h = self.conv_in(x) + h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals) + h = self.middle_block(h, emb, context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + h = self._apply_up_blocks(h, emb, context, _updated_down_block_res_samples) h = self.out(h) h_tensor: torch.Tensor = convert_to_tensor(h) return h_tensor From 4a53fc58d76f98adce3fb3db4e74fc79fb2a3616 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 19:07:28 +0000 Subject: [PATCH 31/36] refactor embedding modules Signed-off-by: dongyang0122 --- .../maisi/networks/diffusion_model_unet_maisi.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index bf9e8f0ac4..a267d902da 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -196,19 +196,13 @@ def __init__( new_time_embed_dim = time_embed_dim if self.input_top_region_index: - self.top_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) + self.top_region_index_layer = self._create_embedding_module(4, time_embed_dim) new_time_embed_dim += time_embed_dim if self.input_bottom_region_index: - self.bottom_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) + self.bottom_region_index_layer = self._create_embedding_module(4, time_embed_dim) new_time_embed_dim += time_embed_dim if self.input_spacing: - self.spacing_layer = nn.Sequential( - nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) + self.spacing_layer = self._create_embedding_module(3, time_embed_dim) new_time_embed_dim += time_embed_dim # down @@ -311,6 +305,10 @@ def __init__( ), ) + def _create_embedding_module(self, input_dim, embed_dim): + model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)) + return model + def _get_time_and_class_embedding(self, x, timesteps, class_labels): t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) From fc3e828a4e92320d5c8391384bdfb2303e30db83 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Fri, 28 Jun 2024 19:11:21 +0000 Subject: [PATCH 32/36] refactor embedding modules Signed-off-by: dongyang0122 --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index a267d902da..33dafb7f80 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -181,9 +181,7 @@ def __init__( # time time_embed_dim = num_channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) + self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim) 
# class embedding self.num_class_embeds = num_class_embeds From 5ea43b978deaad15f21b1daca6f7789fc2a53125 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 1 Jul 2024 13:13:43 +0000 Subject: [PATCH 33/36] refactor embedding modules Signed-off-by: dongyang0122 --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 33dafb7f80..abbbae3e4d 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -321,7 +321,7 @@ def _get_time_and_class_embedding(self, x, timesteps, class_labels): raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels) class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb + emb += class_emb return emb def _get_input_embeddings(self, emb, top_index, bottom_index, spacing): @@ -351,7 +351,7 @@ def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals): for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual + down_block_res_sample += down_block_additional_residual new_down_block_res_samples.append(down_block_res_sample) down_block_res_samples = new_down_block_res_samples @@ -403,7 +403,7 @@ def forward( # Additional residual conections for Controlnets if mid_block_additional_residual is not None: - h = h + mid_block_additional_residual + h += mid_block_additional_residual h = self._apply_up_blocks(h, emb, context, _updated_down_block_res_samples) h = self.out(h) From 06852b8d5e4549f924ed4b797eab06ef1270ace2 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 1 Jul 2024 13:21:48 +0000 Subject: [PATCH 34/36] update Signed-off-by: dongyang0122 --- .../networks/diffusion_model_unet_maisi.py | 30 +++++++++---------- tests/test_diffusion_model_unet_maisi.py | 12 ++++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index abbbae3e4d..fba13db9b1 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -80,9 +80,9 @@ class DiffusionModelUNetMaisi(nn.Module): upcast_attention: If True, upcast attention operations to full precision. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers. - input_top_region_index: If True, use top region index input. - input_bottom_region_index: If True, use bottom region index input. - input_spacing: If True, use spacing input. + include_top_region_index_input: If True, use top region index input. + include_bottom_region_index_input: If True, use bottom region index input. + include_spacing_input: If True, use spacing input. 
""" def __init__( @@ -104,9 +104,9 @@ def __init__( upcast_attention: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, - input_top_region_index: bool = False, - input_bottom_region_index: bool = False, - input_spacing: bool = False, + include_top_region_index_input: bool = False, + include_bottom_region_index_input: bool = False, + include_spacing_input: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -188,18 +188,18 @@ def __init__( if num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - self.input_top_region_index = input_top_region_index - self.input_bottom_region_index = input_bottom_region_index - self.input_spacing = input_spacing + self.include_top_region_index_input = include_top_region_index_input + self.include_bottom_region_index_input = include_bottom_region_index_input + self.include_spacing_input = include_spacing_input new_time_embed_dim = time_embed_dim - if self.input_top_region_index: + if self.include_top_region_index_input: self.top_region_index_layer = self._create_embedding_module(4, time_embed_dim) new_time_embed_dim += time_embed_dim - if self.input_bottom_region_index: + if self.include_bottom_region_index_input: self.bottom_region_index_layer = self._create_embedding_module(4, time_embed_dim) new_time_embed_dim += time_embed_dim - if self.input_spacing: + if self.include_spacing_input: self.spacing_layer = self._create_embedding_module(3, time_embed_dim) new_time_embed_dim += time_embed_dim @@ -325,13 +325,13 @@ def _get_time_and_class_embedding(self, x, timesteps, class_labels): return emb def _get_input_embeddings(self, emb, top_index, bottom_index, spacing): - if self.input_top_region_index: + if self.include_top_region_index_input: _emb = self.top_region_index_layer(top_index) emb = torch.cat((emb, _emb), dim=1) - if self.input_bottom_region_index: + if self.include_bottom_region_index_input: _emb = self.bottom_region_index_layer(bottom_index) emb = torch.cat((emb, _emb), dim=1) - if self.input_spacing: + if self.include_spacing_input: _emb = self.spacing_layer(spacing) emb = torch.cat((emb, _emb), dim=1) return emb diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 3d6ace3b79..b5c14192d9 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -495,9 +495,9 @@ def test_conditioned_2d_models_shape(self, input_param): @parameterized.expand(UNCOND_CASES_2D) @skipUnless(has_einops, "Requires einops") def test_shape_with_additional_inputs(self, input_param): - input_param["input_top_region_index"] = True - input_param["input_bottom_region_index"] = True - input_param["input_spacing"] = True + input_param["include_top_region_index_input"] = True + input_param["include_bottom_region_index_input"] = True + input_param["include_spacing_input"] = True net = DiffusionModelUNetMaisi(**input_param) with eval_mode(net): result = net.forward( @@ -573,9 +573,9 @@ def test_right_dropout(self, input_param): @parameterized.expand(UNCOND_CASES_3D) @skipUnless(has_einops, "Requires einops") def test_shape_with_additional_inputs(self, input_param): - input_param["input_top_region_index"] = True - input_param["input_bottom_region_index"] = True - input_param["input_spacing"] = True + input_param["include_top_region_index_input"] = True + input_param["include_bottom_region_index_input"] = True + input_param["include_spacing_input"] = True net = 
DiffusionModelUNetMaisi(**input_param) with eval_mode(net): result = net.forward( From 757dda464d880c7324996d8202565efec40206c5 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 1 Jul 2024 13:26:40 +0000 Subject: [PATCH 35/36] update Signed-off-by: dongyang0122 --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index fba13db9b1..d5f5f6136b 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -342,8 +342,7 @@ def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals): down_block_res_samples: list[torch.Tensor] = [h] for downsample_block in self.down_blocks: h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) + down_block_res_samples.extend(res_samples) # Additional residual conections for Controlnets if down_block_additional_residuals is not None: From 22582bd4f9f7626593c85d356f110ff17fdbad1d Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Mon, 1 Jul 2024 15:09:06 +0000 Subject: [PATCH 36/36] update Signed-off-by: dongyang0122 --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index b598f301f6..1bba930273 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -58,4 +58,4 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0 -git+https://github.com/KumoLiu/GenerativeModels.git@cuda#egg=monai-generative +git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd
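
The patches above move every `generative` import behind `monai.utils.optional_import` and pin `monai-generative` in `requirements-dev.txt`. A minimal sketch of that optional-dependency pattern, assuming only that `monai` is installed; the probed names mirror the imports used in `diffusion_model_unet_maisi.py`:

from monai.utils import optional_import

# Probe the optional dependencies the same way the network module does.
get_down_block, has_get_down_block = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
)
xformers, has_xformers = optional_import("xformers")

if not has_get_down_block:
    # monai-generative is required before DiffusionModelUNetMaisi can be built.
    raise ImportError("monai-generative is not installed; try `pip install monai-generative`.")
if not has_xformers:
    # Without xformers, memory-efficient (flash) attention is expected to be unavailable.
    print("xformers not found; use_flash_attention should be left at False.")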
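
A usage sketch for the refactored network, following the shapes exercised in `tests/test_diffusion_model_unet_maisi.py`. The constructor arguments below are illustrative only, and the example assumes `einops` and `monai-generative` are available:

import torch

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

net = DiffusionModelUNetMaisi(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=[8],
    norm_num_groups=8,
    attention_levels=[True],
    num_res_blocks=1,
    num_head_channels=8,
    include_top_region_index_input=True,
    include_bottom_region_index_input=True,
    include_spacing_input=True,
)
x = torch.rand(1, 1, 16, 16)                     # noisy input, (N, C, H, W)
timesteps = torch.randint(0, 1000, (1,)).long()  # one timestep per batch element
top_idx = torch.rand(1, 4)                       # top region index tensor, shape (N, 4)
bottom_idx = torch.rand(1, 4)                    # bottom region index tensor, shape (N, 4)
spacing = torch.rand(1, 3)                       # voxel spacing tensor, shape (N, 3)

out = net(
    x,
    timesteps=timesteps,
    top_region_index_tensor=top_idx,
    bottom_region_index_tensor=bottom_idx,
    spacing_tensor=spacing,
)
print(out.shape)  # expected: torch.Size([1, 1, 16, 16])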
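
The embedding refactor factors the repeated Linear-SiLU-Linear stack into `_create_embedding_module` and concatenates each auxiliary embedding onto the time embedding. A standalone sketch of that idea; the names and dimensions here are illustrative, not the module's API:

import torch
from torch import nn

def create_embedding_module(input_dim: int, embed_dim: int) -> nn.Sequential:
    # Same Linear -> SiLU -> Linear pattern used for the time, region-index and spacing embeddings.
    return nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))

time_embed_dim = 32
top_layer = create_embedding_module(4, time_embed_dim)      # top region index has 4 values
spacing_layer = create_embedding_module(3, time_embed_dim)  # spacing has 3 values

emb = torch.rand(2, time_embed_dim)                             # time (+ class) embedding, shape (N, D)
emb = torch.cat((emb, top_layer(torch.rand(2, 4))), dim=1)      # (N, 2*D)
emb = torch.cat((emb, spacing_layer(torch.rand(2, 3))), dim=1)  # (N, 3*D)
print(emb.shape)  # torch.Size([2, 96])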