From 4ca2fd750bcbc57e84a30581183938e0db7d69c5 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 19 Sep 2023 16:47:52 +0100 Subject: [PATCH 1/2] Add pad and cropping options to the Latent Diffusion Inferer (+ test). --- generative/inferers/inferer.py | 29 +++++- tests/test_latent_diffusion_inferer.py | 117 +++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 68b0bdd6..84a563fe 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from monai.inferers import Inferer from monai.utils import optional_import +from monai.transforms import SpatialPad, CenterSpatialCrop tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -303,11 +304,26 @@ class LatentDiffusionInferer(DiffusionInferer): scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. scale_factor: scale factor to multiply the values of the latent representation before processing it by the second stage. + ldm_latent_shape: desired SPATIAL latent space shape. Used if there is a difference in output VAE latent shape + and LDM shape. + vae_latent_shape: VAE SPATIAL latent space shape. Used if there is a difference in output VAE latent shape and + LDM shape. """ - def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None: + def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + vae_latent_shape: list | None = None) -> None: + super().__init__(scheduler=scheduler) self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (vae_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, vae_latent_shape must be None" + "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.vae_latent_shape = vae_latent_shape + if self.ldm_latent_shape is not None: + self.padder = SpatialPad(spatial_size=[-1,]+self.ldm_latent_shape) + self.cropper = CenterSpatialCrop(roi_size=[-1,]+self.vae_latent_shape) def __call__( self, @@ -334,6 +350,9 @@ def __call__( with torch.no_grad(): latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + if self.ldm_latent_shape is not None: + latent = self.padder(latent) + prediction = super().__call__( inputs=latent, diffusion_model=diffusion_model, @@ -386,6 +405,10 @@ def sample( else: latent = outputs + if self.ldm_latent_shape is not None: + latent = self.cropper(latent) + latent_intermediates = [self.cropper(l) for l in latent_intermediates] + image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) if save_intermediates: @@ -437,6 +460,10 @@ def get_likelihood( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = self.padder(latents) + outputs = super().get_likelihood( inputs=latents, diffusion_model=diffusion_model, diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 296e9266..2ed89e1a 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -103,6 +103,89 @@ (1, 3, 4, 4, 4), ], ] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "latent_channels": 3, + "attention_levels": 
[False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], +] class TestDiffusionSamplingInferer(unittest.TestCase): @@ -325,6 +408,40 @@ def test_sample_shape_conditioned_concat( ) self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents(self, + model_type, + autoencoder_params, + stage_2_params, + input_shape, + latent_shape + ): + if model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + vae_latent_shape = [i//(2**(len(autoencoder_params['num_channels'])-1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + vae_latent_shape=vae_latent_shape) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) if __name__ == "__main__": unittest.main() From 0cd2e528ea5de5ede0ead08742ccc9ccd6db588d Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 23 Oct 2023 16:30:29 +0100 Subject: [PATCH 2/2] Added changes: mainly changing the argument descriptions, and substituting VAE > autoencoder for generalisation purposes, as well as padder and cropper by ldm_resizer and autoencoder_resizer. 
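
As a reference for reviewers, a minimal sketch of what the renamed resizers do
internally (the latent shapes below are illustrative only): the autoencoder
latent is padded up to the spatial shape expected by the diffusion model before
denoising, and cropped back to the autoencoder's shape before decoding.

    import torch
    from monai.transforms import CenterSpatialCrop, SpatialPad

    ldm_latent_shape = [8, 8]          # spatial shape expected by the diffusion model
    autoencoder_latent_shape = [6, 6]  # spatial shape produced by the autoencoder

    # MONAI's pad/crop transforms treat dim 0 as the channel axis and the rest as
    # spatial. Applied to a batched latent (B, C, H, W), the batch axis plays the
    # channel role, and the leading -1 keeps the latent-channel size unchanged.
    ldm_resizer = SpatialPad(spatial_size=[-1] + ldm_latent_shape)
    autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + autoencoder_latent_shape)

    latent = torch.randn(1, 3, 6, 6)       # (B, C, H, W) latent from the autoencoder
    padded = ldm_resizer(latent)           # (1, 3, 8, 8), fed to the diffusion model
    cropped = autoencoder_resizer(padded)  # (1, 3, 6, 6), ready for decoding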
---
 generative/inferers/inferer.py         | 29 +++++++++++---------------
 tests/test_latent_diffusion_inferer.py |  4 ++--
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 84a563fe..37229b49 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -23,7 +23,6 @@
 
 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
 
-
 class DiffusionInferer(Inferer):
     """
     DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
@@ -294,7 +293,6 @@ def _get_decoder_log_likelihood(
         assert log_probs.shape == inputs.shape
         return log_probs
 
-
 class LatentDiffusionInferer(DiffusionInferer):
     """
     LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
@@ -304,26 +302,24 @@ class LatentDiffusionInferer(DiffusionInferer):
         scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
         scale_factor: scale factor to multiply the values of the latent representation before processing it by the
             second stage.
-        ldm_latent_shape: desired SPATIAL latent space shape. Used if there is a difference in output VAE latent shape
-            and LDM shape.
-        vae_latent_shape: VAE SPATIAL latent space shape. Used if there is a difference in output VAE latent shape and
-            LDM shape.
+        ldm_latent_shape: spatial latent space shape expected by the diffusion model. Used if it differs from the autoencoder's latent shape.
+        autoencoder_latent_shape: spatial latent space shape produced by the autoencoder. Used if it differs from the shape expected by the diffusion model.
     """
 
     def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0,
                  ldm_latent_shape: list | None = None,
-                 vae_latent_shape: list | None = None) -> None:
+                 autoencoder_latent_shape: list | None = None) -> None:
 
         super().__init__(scheduler=scheduler)
         self.scale_factor = scale_factor
-        if (ldm_latent_shape is None) ^ (vae_latent_shape is None):
-            raise ValueError("If ldm_latent_shape is None, vae_latent_shape must be None"
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None "
                              "and vice versa.")
         self.ldm_latent_shape = ldm_latent_shape
-        self.vae_latent_shape = vae_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
         if self.ldm_latent_shape is not None:
-            self.padder = SpatialPad(spatial_size=[-1,]+self.ldm_latent_shape)
-            self.cropper = CenterSpatialCrop(roi_size=[-1,]+self.vae_latent_shape)
+            self.ldm_resizer = SpatialPad(spatial_size=[-1,]+self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1,]+self.autoencoder_latent_shape)
 
     def __call__(
         self,
@@ -351,7 +347,7 @@ def __call__(
             latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
         if self.ldm_latent_shape is not None:
-            latent = self.padder(latent)
+            latent = self.ldm_resizer(latent)
 
         prediction = super().__call__(
             inputs=latent,
@@ -406,8 +402,8 @@ def sample(
             latent = outputs
 
         if self.ldm_latent_shape is not None:
-            latent = self.cropper(latent)
-            latent_intermediates = [self.cropper(l) for l in latent_intermediates]
+            latent = self.autoencoder_resizer(latent)
+            latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
 
         image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
 
@@ -462,7 +458,7 @@ def get_likelihood(
         latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
         if self.ldm_latent_shape is not None:
-            latents = self.padder(latents)
+            latents = self.ldm_resizer(latents)
 
         outputs = super().get_likelihood(
             inputs=latents,
@@ -480,7 +476,6 @@ def get_likelihood(
             outputs = (outputs[0], intermediates)
         return outputs
 
-
 class VQVAETransformerInferer(Inferer):
     """
     Class to perform inference with a VQVAE + Transformer model.
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index 2ed89e1a..ba607c34 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -432,10 +432,10 @@ def test_sample_shape_different_latents(self,
         noise = torch.randn(latent_shape).to(device)
         scheduler = DDPMScheduler(num_train_timesteps=10)
         # We infer the VAE shape
-        vae_latent_shape = [i//(2**(len(autoencoder_params['num_channels'])-1)) for i in input_shape[2:]]
+        autoencoder_latent_shape = [i//(2**(len(autoencoder_params['num_channels'])-1)) for i in input_shape[2:]]
         inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0,
                                          ldm_latent_shape=list(latent_shape[2:]),
-                                         vae_latent_shape=vae_latent_shape)
+                                         autoencoder_latent_shape=autoencoder_latent_shape)
         scheduler.set_timesteps(num_inference_steps=10)
         timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
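
For context, a minimal end-to-end sketch of the new arguments (the tiny network
configurations below mirror the 2D AutoencoderKL test case above and are for
illustration only):

    import torch

    from generative.inferers import LatentDiffusionInferer
    from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
    from generative.networks.schedulers import DDPMScheduler

    # 12x12 inputs through a 2-level autoencoder give 6x6 latents, while the
    # diffusion model here works on 8x8 latents, so the two shape arguments differ.
    autoencoder = AutoencoderKL(
        spatial_dims=2, in_channels=1, out_channels=1, num_channels=(4, 4),
        latent_channels=3, attention_levels=(False, False), num_res_blocks=1,
        with_encoder_nonlocal_attn=False, with_decoder_nonlocal_attn=False,
        norm_num_groups=4,
    )
    unet = DiffusionModelUNet(
        spatial_dims=2, in_channels=3, out_channels=3, num_channels=(4, 4),
        attention_levels=(False, False), num_res_blocks=1, num_head_channels=4,
        norm_num_groups=4,
    )
    scheduler = DDPMScheduler(num_train_timesteps=10)

    inferer = LatentDiffusionInferer(
        scheduler=scheduler,
        scale_factor=1.0,
        ldm_latent_shape=[8, 8],          # latents are padded to 8x8 before the UNet
        autoencoder_latent_shape=[6, 6],  # and cropped back to 6x6 before decoding
    )

    images = torch.randn(1, 1, 12, 12)
    noise = torch.randn(1, 3, 8, 8)
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (1,)).long()

    # Training-style call: predicts the noise added to the padded 8x8 latents.
    prediction = inferer(
        inputs=images, autoencoder_model=autoencoder, diffusion_model=unet,
        noise=noise, timesteps=timesteps,
    )

    # Sampling: denoises 8x8 latents, crops them to 6x6, then decodes 12x12 images.
    scheduler.set_timesteps(num_inference_steps=10)
    samples = inferer.sample(
        input_noise=noise, autoencoder_model=autoencoder, diffusion_model=unet,
        scheduler=scheduler,
    )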