From 1942fac2fd54dca959941da421bda869284dbbaa Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 16 Feb 2023 17:13:11 -0600 Subject: [PATCH 1/4] Clarifies we calculate log-likelihoods in the docs --- generative/inferers/inferer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index b44ac11f..6aea2f71 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -117,7 +117,7 @@ def get_likelihood( verbose: bool | None = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ - Computes the likelihoods for an input. + Computes the log-likelihoods for an input. Args: inputs: input images, NxCxHxW[xD] @@ -376,7 +376,7 @@ def get_likelihood( resample_interpolation_mode: str | None = "bilinear", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ - Computes the likelihoods of the latent representations of the input. + Computes the log-likelihoods of the latent representations of the input. 
Args: inputs: input images, NxCxHxW[xD] From e4bfe15a17f211a3f7c7a6d3e69fe16892058518 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 16 Feb 2023 17:22:33 -0600 Subject: [PATCH 2/4] Adds a 3D test case --- tests/test_latent_diffusion_inferer.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 9623c76a..a049e1d5 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -75,6 +75,33 @@ (1, 1, 16, 16), (1, 3, 4, 4), ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], ] From 6b12bc9ca6d1b554a2f7beebec7906df469362a3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 16 Feb 2023 17:24:52 -0600 Subject: [PATCH 3/4] Harmonises resampling for diffusion and transformer inferers --- generative/inferers/inferer.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 6aea2f71..3245f7b7 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -373,7 +373,7 @@ def get_likelihood( scaled_input_range: tuple | None = (0, 1), verbose: bool | None = True, resample_latent_likelihoods: bool | None = False, - resample_interpolation_mode: str | None = "bilinear", + resample_interpolation_mode: str | None = "nearest", ) -> torch.Tensor | tuple[torch.Tensor, 
list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -390,9 +390,13 @@ def get_likelihood( verbose: if true, prints the progression bar of the sampling process. resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial dimension as the input images. - resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest' or 'bilinear' + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear' """ - + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor outputs = super().get_likelihood( inputs=latents, @@ -404,14 +408,7 @@ def get_likelihood( ) if save_intermediates and resample_latent_likelihoods: intermediates = outputs[1] - from torchvision.transforms import Resize - - interpolation_modes = {"nearest": 0, "bilinear": 2} - if resample_interpolation_mode not in interpolation_modes.keys(): - raise ValueError( - f"resample_interpolation mode should be either nearest or bilinear, not {resample_interpolation_mode}" - ) - resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode]) + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) intermediates = [resizer(x) for x in intermediates] outputs = (outputs[0], intermediates) return outputs @@ -536,7 +533,7 @@ def get_likelihood( ordering: Callable[..., torch.Tensor], condition: torch.Tensor | None = None, resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "trilinear", + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor: """ Computes the log-likelihoods of 
the latent representations of the input. @@ -552,7 +549,7 @@ def get_likelihood( resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear; """ - if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) From 09de310e745847e7049fe5a562ab66cce9b862cc Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 16 Feb 2023 17:37:36 -0600 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- generative/inferers/inferer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 3245f7b7..5d06b12a 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -372,8 +372,8 @@ def get_likelihood( original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), verbose: bool | None = True, - resample_latent_likelihoods: bool | None = False, - resample_interpolation_mode: str | None = "nearest", + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input.