diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index b44ac11f..5d06b12a 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -117,7 +117,7 @@ def get_likelihood( verbose: bool | None = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ - Computes the likelihoods for an input. + Computes the log-likelihoods for an input. Args: inputs: input images, NxCxHxW[xD] @@ -372,11 +372,11 @@ def get_likelihood( original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), verbose: bool | None = True, - resample_latent_likelihoods: bool | None = False, - resample_interpolation_mode: str | None = "bilinear", + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ - Computes the likelihoods of the latent representations of the input. + Computes the log-likelihoods of the latent representations of the input. Args: inputs: input images, NxCxHxW[xD] @@ -390,9 +390,13 @@ def get_likelihood( verbose: if true, prints the progression bar of the sampling process. resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial dimension as the input images. 
- resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest' or 'bilinear' + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear' """ - + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor outputs = super().get_likelihood( inputs=latents, @@ -404,14 +408,7 @@ def get_likelihood( ) if save_intermediates and resample_latent_likelihoods: intermediates = outputs[1] - from torchvision.transforms import Resize - - interpolation_modes = {"nearest": 0, "bilinear": 2} - if resample_interpolation_mode not in interpolation_modes.keys(): - raise ValueError( - f"resample_interpolation mode should be either nearest or bilinear, not {resample_interpolation_mode}" - ) - resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode]) + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) intermediates = [resizer(x) for x in intermediates] outputs = (outputs[0], intermediates) return outputs @@ -536,7 +533,7 @@ def get_likelihood( ordering: Callable[..., torch.Tensor], condition: torch.Tensor | None = None, resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "trilinear", + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor: """ Computes the log-likelihoods of the latent representations of the input. 
@@ -552,7 +549,7 @@ def get_likelihood( resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear; """ - if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 9623c76a..a049e1d5 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -75,6 +75,33 @@ (1, 1, 16, 16), (1, 3, 4, 4), ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "num_channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], ]