diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index f55528a6..b44ac11f 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -432,7 +432,8 @@ def __call__(
         transformer_model: Callable[..., torch.Tensor],
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+        return_latent: bool = False,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
         """
         Implements the forward pass for a supervised training iteration.

@@ -441,11 +442,13 @@ def __call__(
             vqvae_model: first stage model.
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
+            return_latent: also return the latent sequence and the spatial shape of the latent.
             condition: conditioning for network input.
         """
         with torch.no_grad():
             latent = vqvae_model.index_quantize(inputs)

+        latent_spatial_dim = tuple(latent.shape[1:])
         latent = latent.reshape(latent.shape[0], -1)
         latent = latent[:, ordering.get_sequence_ordering()]

@@ -455,8 +458,10 @@ def __call__(
         latent = latent.long()

         prediction = transformer_model(x=latent, context=condition)
-
-        return prediction
+        if return_latent:
+            return prediction, latent, latent_spatial_dim
+        else:
+            return prediction

     @torch.no_grad()
     def sample(
@@ -521,3 +526,57 @@ def sample(
         latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)

         return vqvae_model.decode_samples(latent)
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
+        condition: torch.Tensor | None = None,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "trilinear",
+    ) -> torch.Tensor:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD].
+            vqvae_model: first stage model.
+            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            condition: conditioning for network input.
+            resample_latent_likelihoods: if True, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: when resample_latent_likelihoods is True, the interpolation mode to use:
+                'nearest', 'bilinear', or 'trilinear'.
+        """
+        if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation_mode should be either 'nearest', 'bilinear' or 'trilinear', got {resample_interpolation_mode}"
+            )
+        logits, latent, latent_spatial_dim = self(
+            inputs=inputs,
+            vqvae_model=vqvae_model,
+            transformer_model=transformer_model,
+            ordering=ordering,
+            condition=condition,
+            return_latent=True,
+        )
+        all_probs = F.softmax(logits, dim=-1)
+        probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
+        probs = probs.squeeze(2)
+
+        # remove the probability of the starting (BOS) token
+        probs = probs[:, 1:]
+        probs = torch.log(probs)
+
+        # revert the sequence ordering and reshape to the latent spatial shape
+        probs = probs[:, ordering.get_revert_sequence_ordering()]
+        probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
+        if resample_latent_likelihoods:
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            probs_reshaped = resizer(probs_reshaped[:, None, ...])
+
+        return probs_reshaped
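For orientation, here is a toy, self-contained walk-through of the arithmetic inside `get_likelihood` (not part of the patch). The shapes mirror the 2D test case below (logits of shape (2, 5, 17) for a 2x2 latent plus one BOS token), and the identity `revert` permutation is a stand-in for `ordering.get_revert_sequence_ordering()`:

```python
import torch
import torch.nn.functional as F

# Shapes mirror the 2D test case: batch of 2, 2x2 latent -> 4 tokens,
# plus one BOS token, and a vocabulary of 17 (as in the (2, 5, 17) logits shape).
batch, latent_hw, vocab = 2, (2, 2), 17
seq_len = latent_hw[0] * latent_hw[1] + 1  # 5, BOS included

logits = torch.randn(batch, seq_len, vocab)         # transformer output
latent = torch.randint(0, vocab, (batch, seq_len))  # realised token sequence

# Probability the model assigned to each realised token at each position.
all_probs = F.softmax(logits, dim=-1)
probs = torch.gather(all_probs, 2, latent.unsqueeze(2)).squeeze(2)  # (2, 5)

# Drop the BOS position and move to log space.
log_probs = torch.log(probs[:, 1:])  # (2, 4)

# Undo the sequence ordering (identity stand-in here) and restore the spatial layout.
revert = torch.arange(log_probs.shape[1])
likelihood_map = log_probs[:, revert].reshape((batch,) + latent_hw)
print(likelihood_map.shape)  # torch.Size([2, 2, 2])
```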
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 0c10a31a..972885f6 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -45,6 +45,7 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
         (2, 1, 8, 8),
         (2, 5, 17),
+        (2, 2, 2),
     ],
     [
         {
@@ -70,13 +71,16 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
         (2, 1, 8, 8, 8),
         (2, 9, 17),
+        (2, 2, 2, 2),
     ],
 ]


 class TestVQVAETransformerInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
+    def test_prediction_shape(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
         stage_1 = VQVAE(**stage_1_params)
         stage_2 = DecoderOnlyTransformer(**stage_2_params)
         ordering = Ordering(**ordering_params)
@@ -91,7 +95,7 @@ def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params,
         inferer = VQVAETransformerInferer()
         prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
-        self.assertEqual(prediction.shape, latent_shape)
+        self.assertEqual(prediction.shape, logits_shape)

     def test_sample(self):
         stage_1 = VQVAE(
@@ -135,6 +139,55 @@ def test_sample(self):
         )
         self.assertEqual(sample.shape, (2, 1, 8, 8))

+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(
+            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+        )
+        self.assertEqual(likelihood.shape, latent_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood_resampling(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(
+            inputs=input,
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+            resample_latent_likelihoods=True,
+            resample_interpolation_mode="nearest",
+        )
+        self.assertEqual(likelihood.shape, input_shape)
+

 if __name__ == "__main__":
     unittest.main()
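Beyond the unit tests, a minimal usage sketch of the new API may help. The import path is inferred from the repository layout (generative/inferers/inferer.py), and `vqvae`, `transformer`, and `ordering` stand for already-constructed VQVAE, DecoderOnlyTransformer, and Ordering instances as in the tests above; treat this as a sketch under those assumptions rather than a verified snippet:

```python
import torch

# Import path assumed from the repo layout; not shown in this diff.
from generative.inferers import VQVAETransformerInferer

# `vqvae`, `transformer`, and `ordering` are assumed to be already trained/built,
# mirroring the 2D test case above.
inferer = VQVAETransformerInferer()
images = torch.randn(2, 1, 8, 8)  # NxCxHxW

# Log-likelihood map at latent resolution, e.g. (2, 2, 2) for the 2D test case.
latent_map = inferer.get_likelihood(
    inputs=images, vqvae_model=vqvae, transformer_model=transformer, ordering=ordering
)

# Resampled to image resolution, e.g. (2, 1, 8, 8); low values flag unusual content.
pixel_map = inferer.get_likelihood(
    inputs=images,
    vqvae_model=vqvae,
    transformer_model=transformer,
    ordering=ordering,
    resample_latent_likelihoods=True,
    resample_interpolation_mode="bilinear",
)
```

Note that the default `resample_interpolation_mode="trilinear"` only suits volumetric inputs, since `nn.Upsample` with trilinear mode expects a 5D tensor; 2D callers must pass 'bilinear' or 'nearest' explicitly, as `test_get_likelihood_resampling` does.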