From 08aef5081ab07332e1d01b49ea449e82818f4650 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Fri, 14 Feb 2025 14:36:37 +0000 Subject: [PATCH 1/8] There was a bug in the LatentDiffusionInferer and ControlnetLatentDiffusionInferer when save_intermediates was off but there's need to pad and crop latent space. Signed-off-by: Virginia Fernandez --- monai/inferers/inferer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 769b6cc0e7..7bedc18985 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -18,7 +18,6 @@ from functools import partial from pydoc import locate from typing import Any - import torch import torch.nn as nn import torch.nn.functional as F @@ -1202,9 +1201,10 @@ def sample( # type: ignore[override] if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates - ] + if save_intermediates: + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): @@ -1727,9 +1727,10 @@ def sample( # type: ignore[override] if self.autoencoder_latent_shape is not None: latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) - latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates - ] + if save_intermediates: + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): From e5aa008cc7a8d911695126f5ea785949a4144519 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Fri, 14 Feb 2025 14:38:21 +0000 Subject: [PATCH 2/8] There was a bug in the LatentDiffusionInferer and ControlnetLatentDiffusionInferer when save_intermediates was off but there's need to pad and crop latent space. 
Signed-off-by: Virginia Fernandez --- monai/inferers/inferer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 7bedc18985..8bea63418b 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -18,6 +18,7 @@ from functools import partial from pydoc import locate from typing import Any + import torch import torch.nn as nn import torch.nn.functional as F @@ -1203,7 +1204,8 @@ def sample( # type: ignore[override] latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) if save_intermediates: latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates ] decode = autoencoder_model.decode_stage_2_outputs @@ -1729,7 +1731,8 @@ def sample( # type: ignore[override] latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) if save_intermediates: latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates ] decode = autoencoder_model.decode_stage_2_outputs From ae32d2ef55f697c63a1883b2074239e57b3e55b1 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 18 Feb 2025 08:01:10 +0000 Subject: [PATCH 3/8] Added tests for new functionality. Signed-off-by: Virginia Fernandez --- monai/inferers/inferer.py | 4 +- tests/test_controlnet_inferers.py | 82 +++++++++++++++++++++++++- tests/test_latent_diffusion_inferer.py | 62 ++++++++++++++++++- 3 files changed, 144 insertions(+), 4 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 8bea63418b..e48fd1619d 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1211,8 +1211,10 @@ def sample( # type: ignore[override] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + print("Decoding latents...") + print(latent.shape) image = decode(latent / self.scale_factor) - + print(image.shape) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index e3b0aeb5a2..2ab5cec335 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -722,7 +722,7 @@ def test_prediction_shape( @parameterized.expand(LATENT_CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape( + def test_pred_shape( self, ae_model_type, autoencoder_params, @@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat( @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape_different_latents( + def test_shape_different_latents( self, ae_model_type, autoencoder_params, @@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = 
None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + if ae_model_type == "VQVAE": + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] + else: + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction, _ = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + input_noise=noise, + seg=input_seg, + save_intermediates=True, + ) + else: + prediction = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + input_noise=noise, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=False, + ) + self.assertEqual(prediction.shape, input_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 2e04ad6c5c..4f81b96ca1 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat( @parameterized.expand(TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape_different_latents( + def test_shape_different_latents( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): stage_1 = None @@ -772,6 +772,66 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = 
DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + if ae_model_type == "VQVAE": + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] + else: + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction, _ = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + input_noise=noise, + save_intermediates=True, + seg=input_seg, + ) + else: + prediction = inferer.sample( + autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False + ) + self.assertEqual(prediction.shape, input_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( From 52de9a76fd6881551160169c4ff2dd950efae598 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 18 Feb 2025 08:03:04 +0000 Subject: [PATCH 4/8] Added tests for new functionality. Signed-off-by: Virginia Fernandez --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index bffe304df4..1e20760960 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,7 +17,7 @@ mccabe pep8-naming pycodestyle pyflakes -black>=22.12 +black>=22.12 isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" From 0f5609e14586e6e42ae67034d094a66e7d7250e9 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 18 Feb 2025 08:03:11 +0000 Subject: [PATCH 5/8] Added tests for new functionality. 
Signed-off-by: Virginia Fernandez --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1e20760960..bffe304df4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,7 +17,7 @@ mccabe pep8-naming pycodestyle pyflakes -black>=22.12 +black>=22.12 isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" From 22e21fe4c1c36225dbffb550379226f50b764f60 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 18 Feb 2025 18:45:05 +0800 Subject: [PATCH 6/8] Update monai/inferers/inferer.py Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/inferer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index e48fd1619d..4eb6985c8a 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1214,7 +1214,6 @@ def sample( # type: ignore[override] print("Decoding latents...") print(latent.shape) image = decode(latent / self.scale_factor) - print(image.shape) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: From 83e3dd63aae534183b34d4ef2691eadcd521c8be Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 18 Feb 2025 18:45:14 +0800 Subject: [PATCH 7/8] Update monai/inferers/inferer.py Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/inferer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 4eb6985c8a..7083373859 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1211,8 +1211,6 @@ def sample( # type: ignore[override] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - print("Decoding latents...") - print(latent.shape) image = decode(latent / self.scale_factor) if save_intermediates: intermediates = [] From 4a985da20ec25c11117bf24b9d03bcf246fa07a8 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 18 Feb 2025 10:53:02 +0000 Subject: [PATCH 8/8] Remove debugging related print statements. Signed-off-by: Virginia Fernandez --- monai/inferers/inferer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index e48fd1619d..7083373859 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1211,10 +1211,7 @@ def sample( # type: ignore[override] decode = autoencoder_model.decode_stage_2_outputs if isinstance(autoencoder_model, SPADEAutoencoderKL): decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - print("Decoding latents...") - print(latent.shape) image = decode(latent / self.scale_factor) - print(image.shape) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates:
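
Note on the fix: across both LatentDiffusionInferer.sample and ControlNetLatentDiffusionInferer.sample, the patches make the resizing of the intermediate latents conditional on save_intermediates=True, while the final latent is still always padded/cropped back to the autoencoder's latent shape before decoding. The snippet below is a minimal, self-contained sketch of that control flow only; resize_latent, resize_for_decoding, and the tensor shapes are hypothetical stand-ins for the inferer's autoencoder_resizer and real latents, not MONAI API.

from __future__ import annotations

import torch
import torch.nn.functional as F


def resize_latent(latent: torch.Tensor, target_spatial: tuple[int, ...]) -> torch.Tensor:
    """Stand-in for the pad/crop resizer the inferer applies to each latent."""
    return F.interpolate(latent, size=target_spatial, mode="nearest")


def resize_for_decoding(
    latent: torch.Tensor,
    latent_intermediates: list[torch.Tensor] | None,
    target_spatial: tuple[int, ...],
    save_intermediates: bool,
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
    # The final latent is always resized back to the autoencoder's latent shape.
    latent = resize_latent(latent, target_spatial)

    # Before the fix, the intermediates were resized unconditionally, which failed
    # when save_intermediates=False because no intermediates had been collected.
    if save_intermediates and latent_intermediates is not None:
        latent_intermediates = [resize_latent(li, target_spatial) for li in latent_intermediates]

    return latent, latent_intermediates


if __name__ == "__main__":
    noise_like_latent = torch.randn(1, 3, 16, 16)  # e.g. a padded LDM latent
    resized, intermediates = resize_for_decoding(
        noise_like_latent, None, target_spatial=(12, 12), save_intermediates=False
    )
    print(resized.shape, intermediates)  # torch.Size([1, 3, 12, 12]) None

In the actual inferers, as the diffs above show, the resizer is applied per decollated sample and the results are re-stacked with torch.stack; the new test_sample_shape_different_latents tests exercise the save_intermediates=False path with mismatched ldm_latent_shape and autoencoder_latent_shape to cover this case.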