From fb1f91d6a08de0f54d4504e07fb3efed5af5bae0 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 13:00:09 -0600
Subject: [PATCH 1/7] Adds transformer get_likelihood

---
 generative/inferers/inferer.py         | 58 ++++++++++++++++++++++++++
 tests/test_vqvaetransformer_inferer.py | 41 +++++++++++++++++-
 2 files changed, 97 insertions(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index f55528a6..5c3f205c 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -521,3 +521,61 @@ def sample(
         latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)
 
         return vqvae_model.decode_samples(latent)
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
+        condition: torch.Tensor | None = None,
+        resample_latent_likelihoods: bool | None = False,
+        resample_interpolation_mode: str | None = "trilinear",
+    ) -> torch.Tensor:
+        """
+        Computes the likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            vqvae_model: first stage model.
+            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            condition: conditioning for network input.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear;
+        """
+        if resample_interpolation_mode not in ('nearest', 'bilinear', 'trilinear'):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent_spatial_dim = tuple(latent.shape[1:])
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        latent = latent.long()
+
+        logits = transformer_model(x=latent, context=condition)
+        all_probs = F.softmax(logits, dim=-1)
+        probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
+        probs = probs.squeeze(2)
+
+        # remove the probability of the starting token
+        probs = probs[:,1:]
+        probs = torch.log(probs)
+
+        # reshape
+        probs = probs[:, ordering.get_revert_sequence_ordering()]
+        probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
+        if resample_latent_likelihoods:
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            probs_reshaped = resizer(probs_reshaped[:,None,...])
+
+        return probs_reshaped
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 0c10a31a..a3c054a1 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -45,6 +45,7 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
         (2, 1, 8, 8),
         (2, 5, 17),
+        (2,2,2)
     ],
     [
         {
@@ -70,13 +71,15 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
         (2, 1, 8, 8, 8),
         (2, 9, 17),
+        (2, 2, 2, 2)
+
     ],
 ]
 
 
 class TestVQVAETransformerInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
+    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
         stage_1 = VQVAE(**stage_1_params)
         stage_2 = DecoderOnlyTransformer(**stage_2_params)
         ordering = Ordering(**ordering_params)
@@ -91,7 +94,7 @@ def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params,
         inferer = VQVAETransformerInferer()
         prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
 
-        self.assertEqual(prediction.shape, latent_shape)
+        self.assertEqual(prediction.shape, logits_shape)
 
     def test_sample(self):
         stage_1 = VQVAE(
@@ -135,6 +138,40 @@ def test_sample(self):
         )
         self.assertEqual(sample.shape, (2, 1, 8, 8))
 
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+        self.assertEqual(likelihood.shape, latent_shape)
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood_resampling(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering, resample_latent_likelihoods=True, resample_interpolation_mode='nearest')
+        self.assertEqual(likelihood.shape, input_shape)
 
 if __name__ == "__main__":
     unittest.main()

From 2653a1d40d069fa3db919b39f2bab94d0755f687 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 14:17:13 -0600
Subject: [PATCH 2/7] Code formatting

---
 generative/inferers/inferer.py         |  6 ++---
 tests/test_vqvaetransformer_inferer.py | 32 +++++++++++++++++++-------
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 5c3f205c..d8b4184e 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -546,7 +546,7 @@ def get_likelihood(
                 dimension as the input images.
             resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear;
         """
-        if resample_interpolation_mode not in ('nearest', 'bilinear', 'trilinear'):
+        if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
@@ -568,7 +568,7 @@ def get_likelihood(
         probs = probs.squeeze(2)
 
         # remove the probability of the starting token
-        probs = probs[:,1:]
+        probs = probs[:, 1:]
         probs = torch.log(probs)
 
         # reshape
@@ -576,6 +576,6 @@ def get_likelihood(
         probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
         if resample_latent_likelihoods:
             resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
-            probs_reshaped = resizer(probs_reshaped[:,None,...])
+            probs_reshaped = resizer(probs_reshaped[:, None, ...])
 
         return probs_reshaped
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index a3c054a1..972885f6 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -45,7 +45,7 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
         (2, 1, 8, 8),
         (2, 5, 17),
-        (2,2,2)
+        (2, 2, 2),
     ],
     [
         {
@@ -70,15 +71,16 @@
         {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
         (2, 1, 8, 8, 8),
         (2, 9, 17),
-        (2, 2, 2, 2)
-
+        (2, 2, 2, 2),
     ],
 ]
 
 
 class TestVQVAETransformerInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
+    def test_prediction_shape(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
         stage_1 = VQVAE(**stage_1_params)
         stage_2 = DecoderOnlyTransformer(**stage_2_params)
         ordering = Ordering(**ordering_params)
@@ -139,7 +140,9 @@ def test_sample(self):
         self.assertEqual(sample.shape, (2, 1, 8, 8))
 
     @parameterized.expand(TEST_CASES)
-    def test_get_likelihood(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
+    def test_get_likelihood(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
         stage_1 = VQVAE(**stage_1_params)
         stage_2 = DecoderOnlyTransformer(**stage_2_params)
         ordering = Ordering(**ordering_params)
@@ -153,10 +156,15 @@ def test_get_likelihood(self, stage_1_params, stage_2_params, ordering_params, i
         input = torch.randn(input_shape).to(device)
 
         inferer = VQVAETransformerInferer()
-        likelihood = inferer.get_likelihood(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+        likelihood = inferer.get_likelihood(
+            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+        )
         self.assertEqual(likelihood.shape, latent_shape)
+
     @parameterized.expand(TEST_CASES)
-    def test_get_likelihood_resampling(self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape):
+    def test_get_likelihood_resampling(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
         stage_1 = VQVAE(**stage_1_params)
         stage_2 = DecoderOnlyTransformer(**stage_2_params)
         ordering = Ordering(**ordering_params)
@@ -170,8 +178,16 @@ def test_get_likelihood_resampling(self, stage_1_params, stage_2_params, orderin
         input = torch.randn(input_shape).to(device)
 
         inferer = VQVAETransformerInferer()
-        likelihood = inferer.get_likelihood(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering, resample_latent_likelihoods=True, resample_interpolation_mode='nearest')
+        likelihood = inferer.get_likelihood(
+            inputs=input,
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+            resample_latent_likelihoods=True,
+            resample_interpolation_mode="nearest",
+        )
         self.assertEqual(likelihood.shape, input_shape)
+
 
 if __name__ == "__main__":
     unittest.main()

From 0603e987b19c5153b851f35a12d81d5bbd9c2594 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 14:45:39 -0600
Subject: [PATCH 3/7] Addressed comments

---
 generative/inferers/inferer.py | 29 +++++++++++------------------
 1 file changed, 11 insertions(+), 18 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index d8b4184e..940be063 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -432,7 +432,8 @@ def __call__(
         transformer_model: Callable[..., torch.Tensor],
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+        return_latent: bool = False
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         Implements the forward pass for a supervised training iteration.
 
@@ -441,11 +442,13 @@ def __call__(
             vqvae_model: first stage model.
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
+            return_latent: also return latent sequence and spatial dim of the latent.
             condition: conditioning for network input.
         """
         with torch.no_grad():
             latent = vqvae_model.index_quantize(inputs)
 
+        latent_spatial_dim = tuple(latent.shape[1:])
         latent = latent.reshape(latent.shape[0], -1)
         latent = latent[:, ordering.get_sequence_ordering()]
 
@@ -455,8 +458,10 @@ def __call__(
         latent = latent.long()
 
         prediction = transformer_model(x=latent, context=condition)
-
-        return prediction
+        if return_latent:
+            return prediction, latent, latent_spatial_dim
+        else:
+            return prediction
 
     @torch.no_grad()
     def sample(
@@ -531,10 +536,10 @@ def get_likelihood(
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
         resample_latent_likelihoods: bool | None = False,
-        resample_interpolation_mode: str | None = "trilinear",
+        resample_interpolation_mode: str = "trilinear",
     ) -> torch.Tensor:
         """
-        Computes the likelihoods of the latent representations of the input.
+        Computes the log-likelihoods of the latent representations of the input.
 
         Args:
             inputs: input images, NxCxHxW[xD]
             vqvae_model: first stage model.
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
             condition: conditioning for network input.
             resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                 dimension as the input images.
             resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear;
         """
         if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
-        with torch.no_grad():
-            latent = vqvae_model.index_quantize(inputs)
-
-        latent_spatial_dim = tuple(latent.shape[1:])
-        latent = latent.reshape(latent.shape[0], -1)
-        latent = latent[:, ordering.get_sequence_ordering()]
-
-        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
-        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
-        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
-        latent = latent.long()
-
-        logits = transformer_model(x=latent, context=condition)
+        logits, latent, latent_spatial_dim = self(inputs=inputs, vqvae_model=vqvae_model, transformer_model=transformer_model, ordering = ordering, condition=condition, return_latent=True)
         all_probs = F.softmax(logits, dim=-1)
         probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
         probs = probs.squeeze(2)

From c5a1ace1a1a18aa083ed131109413d09b0da4167 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 14:46:19 -0600
Subject: [PATCH 4/7] Fixes formatting

---
 generative/inferers/inferer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 940be063..8cd2043c 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -432,7 +432,7 @@ def __call__(
         transformer_model: Callable[..., torch.Tensor],
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
-        return_latent: bool = False
+        return_latent: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         Implements the forward pass for a supervised training iteration.
@@ -555,7 +555,14 @@ def get_likelihood(
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
-        logits, latent, latent_spatial_dim = self(inputs=inputs, vqvae_model=vqvae_model, transformer_model=transformer_model, ordering = ordering, condition=condition, return_latent=True)
+        logits, latent, latent_spatial_dim = self(
+            inputs=inputs,
+            vqvae_model=vqvae_model,
+            transformer_model=transformer_model,
+            ordering=ordering,
+            condition=condition,
+            return_latent=True,
+        )
         all_probs = F.softmax(logits, dim=-1)
         probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
         probs = probs.squeeze(2)

From ac71f7dbf25191d130806817f360f225a56da2bd Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 15:35:06 -0600
Subject: [PATCH 5/7] Fixes typing

---
 generative/inferers/inferer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 8cd2043c..391dd49a 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -433,7 +433,7 @@ def __call__(
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
         return_latent: bool = False,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
         """
         Implements the forward pass for a supervised training iteration.
 
From 870cb7727071838b21c21a610cc3fe99a4f577b2 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 16:58:48 -0600
Subject: [PATCH 6/7] Fix argument type

---
 generative/inferers/inferer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 391dd49a..86c1fd4e 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -535,7 +535,7 @@ def get_likelihood(
         transformer_model: Callable[..., torch.Tensor],
         ordering: Callable[..., torch.Tensor],
         condition: torch.Tensor | None = None,
-        resample_latent_likelihoods: bool | None = False,
+        resample_latent_likelihoods: bool = False,
         resample_interpolation_mode: str = "trilinear",
     ) -> torch.Tensor:
         """

From 8e3d72fca7d0c92bede0c393617d132124a71fac Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 15 Feb 2023 17:16:25 -0600
Subject: [PATCH 7/7] Fixes docstring

---
 generative/inferers/inferer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 86c1fd4e..b44ac11f 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -549,7 +549,8 @@ def get_likelihood(
             condition: conditioning for network input.
             resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                 dimension as the input images.
-            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear;
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
         """
         if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
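
A minimal usage sketch of the method this series adds, for orientation only. It assumes stage_1 (a VQVAE), stage_2 (a DecoderOnlyTransformer) and ordering (an Ordering) have already been built and moved to the target device as in the tests above; images stands in for a hypothetical NxCxHxW[xD] batch, and none of the variable names or hyperparameters below are taken from the patches themselves.

# Illustrative sketch, not part of the patch series.
inferer = VQVAETransformerInferer()

# Per-token log-likelihood map, shaped like the quantised latent grid (batch first).
log_likelihood = inferer.get_likelihood(
    inputs=images, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
)

# The same map resampled back to the spatial size of the input images.
log_likelihood_resampled = inferer.get_likelihood(
    inputs=images,
    vqvae_model=stage_1,
    transformer_model=stage_2,
    ordering=ordering,
    resample_latent_likelihoods=True,
    resample_interpolation_mode="nearest",
)

# Summing the per-token values gives each image's latent-sequence log-likelihood
# under the transformer prior.
per_image_log_likelihood = log_likelihood.reshape(log_likelihood.shape[0], -1).sum(dim=1)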