diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 68b0bdd6..37229b49 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -19,10 +19,10 @@ import torch.nn.functional as F
 from monai.inferers import Inferer
 from monai.utils import optional_import
+from monai.transforms import SpatialPad, CenterSpatialCrop
 
 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
 
-
 class DiffusionInferer(Inferer):
     """
     DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
@@ -293,7 +293,6 @@ def _get_decoder_log_likelihood(
         assert log_probs.shape == inputs.shape
         return log_probs
 
-
 class LatentDiffusionInferer(DiffusionInferer):
     """
     LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
@@ -303,11 +302,24 @@ class LatentDiffusionInferer(DiffusionInferer):
         scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
         scale_factor: scale factor to multiply the values of the latent representation before processing it by the
             second stage.
+        ldm_latent_shape: desired spatial latent shape for the diffusion model, if it differs from the autoencoder's latent shape.
+        autoencoder_latent_shape: spatial latent shape produced by the autoencoder, if it differs from the diffusion model's latent shape.
     """
 
-    def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None:
+    def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0,
+                 ldm_latent_shape: list | None = None,
+                 autoencoder_latent_shape: list | None = None) -> None:
+
         super().__init__(scheduler=scheduler)
         self.scale_factor = scale_factor
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None"
+                             " and vice versa.")
+        self.ldm_latent_shape = ldm_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
+        if self.ldm_latent_shape is not None:
+            self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
 
     def __call__(
         self,
@@ -334,6 +346,9 @@ def __call__(
         with torch.no_grad():
             latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
 
+        if self.ldm_latent_shape is not None:
+            latent = self.ldm_resizer(latent)
+
         prediction = super().__call__(
             inputs=latent,
             diffusion_model=diffusion_model,
@@ -386,6 +401,11 @@ def sample(
         else:
             latent = outputs
 
+        if self.ldm_latent_shape is not None:
+            latent = self.autoencoder_resizer(latent)
+            if save_intermediates:
+                latent_intermediates = [self.autoencoder_resizer(li) for li in latent_intermediates]
+
         image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
 
         if save_intermediates:
@@ -437,6 +457,10 @@ def get_likelihood(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
         latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latents = self.ldm_resizer(latents)
+
         outputs = super().get_likelihood(
             inputs=latents,
             diffusion_model=diffusion_model,
@@ -453,7 +477,6 @@ def get_likelihood(
             outputs = (outputs[0], intermediates)
         return outputs
 
-
 class VQVAETransformerInferer(Inferer):
     """
     Class to perform inference with a VQVAE + Transformer model.
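Usage sketch for the new arguments (a minimal, hypothetical example: the scheduler settings and concrete sizes are illustrative and not taken from the patch). With both shapes set, the inferer pads the encoded latents up to `ldm_latent_shape` via `SpatialPad` before they reach the diffusion model, and center-crops sampled latents back to `autoencoder_latent_shape` via `CenterSpatialCrop` before decoding:

    # sketch only: an autoencoder producing 6x6 latents paired with a diffusion UNet
    # that operates on 8x8 latents (sizes are illustrative)
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = LatentDiffusionInferer(
        scheduler=scheduler,
        scale_factor=1.0,
        ldm_latent_shape=[8, 8],  # spatial shape fed to the diffusion model
        autoencoder_latent_shape=[6, 6],  # spatial shape produced by the autoencoder
    )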
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index 296e9266..ba607c34 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -103,6 +103,89 @@
         (1, 3, 4, 4, 4),
     ],
 ]
+TEST_CASES_DIFF_SHAPES = [
+    [
+        "AutoencoderKL",
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_channels": (4, 4),
+            "latent_channels": 3,
+            "attention_levels": [False, False],
+            "num_res_blocks": 1,
+            "with_encoder_nonlocal_attn": False,
+            "with_decoder_nonlocal_attn": False,
+            "norm_num_groups": 4,
+        },
+        {
+            "spatial_dims": 2,
+            "in_channels": 3,
+            "out_channels": 3,
+            "num_channels": [4, 4],
+            "norm_num_groups": 4,
+            "attention_levels": [False, False],
+            "num_res_blocks": 1,
+            "num_head_channels": 4,
+        },
+        (1, 1, 12, 12),
+        (1, 3, 8, 8),
+    ],
+    [
+        "VQVAE",
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_channels": [4, 4],
+            "num_res_layers": 1,
+            "num_res_channels": [4, 4],
+            "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+            "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+            "num_embeddings": 16,
+            "embedding_dim": 3,
+        },
+        {
+            "spatial_dims": 2,
+            "in_channels": 3,
+            "out_channels": 3,
+            "num_channels": [8, 8],
+            "norm_num_groups": 8,
+            "attention_levels": [False, False],
+            "num_res_blocks": 1,
+            "num_head_channels": 8,
+        },
+        (1, 1, 12, 12),
+        (1, 3, 8, 8),
+    ],
+    [
+        "VQVAE",
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_channels": [4, 4],
+            "num_res_layers": 1,
+            "num_res_channels": [4, 4],
+            "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+            "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+            "num_embeddings": 16,
+            "embedding_dim": 3,
+        },
+        {
+            "spatial_dims": 3,
+            "in_channels": 3,
+            "out_channels": 3,
+            "num_channels": [8, 8],
+            "norm_num_groups": 8,
+            "attention_levels": [False, False],
+            "num_res_blocks": 1,
+            "num_head_channels": 8,
+        },
+        (1, 1, 12, 12, 12),
+        (1, 3, 8, 8, 8),
+    ],
+]
 
 
 class TestDiffusionSamplingInferer(unittest.TestCase):
@@ -325,6 +408,40 @@ def test_sample_shape_conditioned_concat(
         )
         self.assertEqual(sample.shape, input_shape)
 
+    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+    def test_sample_shape_different_latents(self,
+                                            model_type,
+                                            autoencoder_params,
+                                            stage_2_params,
+                                            input_shape,
+                                            latent_shape
+                                            ):
+        if model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+        noise = torch.randn(latent_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # infer the autoencoder's spatial latent shape from its number of downsampling levels
+        autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["num_channels"]) - 1)) for i in input_shape[2:]]
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0,
+                                         ldm_latent_shape=list(latent_shape[2:]),
+                                         autoencoder_latent_shape=autoencoder_latent_shape)
+        scheduler.set_timesteps(num_inference_steps=10)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+        prediction = inferer(
+            inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+        )
+        self.assertEqual(prediction.shape, latent_shape)
 
 if __name__ == "__main__":
     unittest.main()
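For reference, a worked instance of the latent-shape arithmetic used in the AutoencoderKL test case above (this only restates values already present in the test, under its assumption that the autoencoder downsamples by 2 ** (len(num_channels) - 1)):

    # AutoencoderKL with num_channels of length 2 downsamples by 2 ** (2 - 1) = 2,
    # so a (1, 1, 12, 12) input is encoded to latents of spatial shape [6, 6];
    # SpatialPad then grows them to the ldm_latent_shape [8, 8] taken from latent_shape (1, 3, 8, 8).
    autoencoder_latent_shape = [12 // 2, 12 // 2]  # == [6, 6]
    ldm_latent_shape = [8, 8]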