From 1565a32f9ab152d0beca9ec83304fe9b150329c9 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 10:04:09 -0600
Subject: [PATCH 1/6] Run on cropped sequence if exceeds max length

---
 generative/inferers/inferer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 40c3459b..190da244 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -454,9 +454,16 @@ def __call__(
         latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
         latent = latent.long()
 
-        prediction = transformer_model(x=latent, context=condition)
+        # train on a part of the sequence if it is longer than max_seq_length
+        seq_len = latent.shape[1]
+        max_seq_len = transformer_model.max_seq_len
+        if max_seq_len < seq_len:
+            start = torch.randint(low=0, high=seq_len + 1 -max_seq_len, size=(1,)).item()
+        else:
+            start = 0
+        prediction = transformer_model(x=latent[:,start:start+max_seq_len], context=condition)
 
         if return_latent:
-            return prediction, latent, latent_spatial_dim
+            return prediction, latent[:, start:start + max_seq_len], latent_spatial_dim
         else:
             return prediction

From aeaa52bc40256fcb0c0555c8f28b18ce9cee3b02 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 10:04:20 -0600
Subject: [PATCH 2/6] Adds test

---
 tests/test_vqvaetransformer_inferer.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 972885f6..199181c5 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -97,6 +97,29 @@ def test_prediction_shape(
         prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
         self.assertEqual(prediction.shape, logits_shape)
 
+    @parameterized.expand(TEST_CASES)
+    def test_prediction_shape_shorter_sequence(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        max_seq_len = 3
+        stage_2_params['max_seq_len'] = max_seq_len
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+        cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2])
+        self.assertEqual(prediction.shape, cropped_logits_shape)
+
     def test_sample(self):
         stage_1 = VQVAE(
             spatial_dims=2,

From f164e4c47b28aa1d04fcf359caec7cf44af6f0ec Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 10:05:47 -0600
Subject: [PATCH 3/6] Adds sampling test

---
 tests/test_vqvaetransformer_inferer.py | 42 ++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 199181c5..33b9b7a2 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -162,6 +162,48 @@ def test_sample(self):
         )
         self.assertEqual(sample.shape, (2, 1, 8, 8))
 
+    def test_sample_shorter_sequence(self):
+        stage_1 = VQVAE(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_channels=(8, 8),
+            num_res_channels=(8, 8),
+            downsample_parameters=((2, 4, 1, 1),) * 2,
+            upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+            num_res_layers=1,
+            num_embeddings=16,
+            embedding_dim=8,
+        )
+        stage_2 = DecoderOnlyTransformer(
+            num_tokens=16 + 1,
+            max_seq_len=3,
+            attn_layers_dim=4,
+            attn_layers_depth=2,
+            attn_layers_heads=1,
+            with_cross_attention=False,
+        )
+        ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        inferer = VQVAETransformerInferer()
+
+        starting_token = 16  # from stage_1 num_embeddings
+
+        sample = inferer.sample(
+            latent_spatial_dim=(2, 2),
+            starting_tokens=starting_token * torch.ones((2, 1), device=device),
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+        )
+        self.assertEqual(sample.shape, (2, 1, 8, 8))
+
     @parameterized.expand(TEST_CASES)
     def test_get_likelihood(
         self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape

From 65e076a54df3ecb1abc91f09f2615ba4fbb67783 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 12:54:35 -0600
Subject: [PATCH 4/6] Add likelihood tests

---
 tests/test_vqvaetransformer_inferer.py | 29 ++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 33b9b7a2..d901d307 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -103,8 +103,9 @@ def test_prediction_shape_shorter_sequence(
     ):
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
-        stage_2_params['max_seq_len'] = max_seq_len
-        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter['max_seq_len'] = max_seq_len
+        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -227,6 +228,30 @@ def test_get_likelihood(
         self.assertEqual(likelihood.shape, latent_shape)
 
     @parameterized.expand(TEST_CASES)
+    def test_get_likelihood_shorter_sequence(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        max_seq_len = 3
+        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter['max_seq_len'] = max_seq_len
+        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(
+            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+        )
+        self.assertEqual(likelihood.shape, latent_shape)
+    @parameterized.expand(TEST_CASES)
     def test_get_likelihood_resampling(
         self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
     ):

From c604919daa388338c8e02efe5e3d201b541d0462 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 12:54:56 -0600
Subject: [PATCH 5/6] Updates likelihood inferer

---
 generative/inferers/inferer.py | 53 ++++++++++++++++++++++++++--------
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 190da244..85f6a0ff 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -541,6 +541,7 @@ def get_likelihood(
         condition: torch.Tensor | None = None,
         resample_latent_likelihoods: bool = False,
         resample_interpolation_mode: str = "nearest",
+        verbose: bool = False,
     ) -> torch.Tensor:
         """
         Computes the log-likelihoods of the latent representations of the input.
@@ -555,24 +556,52 @@ def get_likelihood(
                 dimension as the input images.
             resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
                 or 'trilinear;
+            verbose: if true, prints the progression bar of the sampling process.
+
         """
         if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
-        logits, latent, latent_spatial_dim = self(
-            inputs=inputs,
-            vqvae_model=vqvae_model,
-            transformer_model=transformer_model,
-            ordering=ordering,
-            condition=condition,
-            return_latent=True,
-        )
-        all_probs = F.softmax(logits, dim=-1)
-        probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
-        probs = probs.squeeze(2)
-        # remove the probability of the starting token
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent_spatial_dim = tuple(latent.shape[1:])
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+        seq_len = math.prod(latent_spatial_dim)
+
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        latent = latent.long()
+
+        # get the first batch, up to max_seq_length, efficiently
+        logits = transformer_model(x=latent[:,:transformer_model.max_seq_len], context=condition)
+        probs = F.softmax(logits, dim=-1)
+        probs = torch.gather(probs, 2, latent[:, :transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
+
+        # if we have not covered the full sequence we continue with inefficient looping
+        if probs.shape[1] < latent.shape[1]:
+            if verbose and has_tqdm:
+                progress_bar = tqdm(range(transformer_model.max_seq_len,seq_len+1))
+            else:
+                progress_bar = iter(range(transformer_model.max_seq_len,seq_len+1))
+
+            for i in progress_bar:
+                idx_cond = latent[:,i+1-transformer_model.max_seq_len:i+1]
+                # forward the model to get the logits for the index in the sequence
+                logits = transformer_model(x=idx_cond, context=condition)
+                # pluck the logits at the final step
+                logits = logits[:, -1, :]
+                # apply softmax to convert logits to (normalized) probabilities
+                p = F.softmax(logits, dim=-1)
+                # select correct values and append
+                p = torch.gather(p, 1, idx_cond[:,-1].unsqueeze(1))
+                probs = torch.cat((probs, p), dim=1)
+
+        # remove starting token probability
         probs = probs[:, 1:]
         probs = torch.log(probs)

From d093cc75dd6153861014dba9f0ca7328419594e5 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 2 Mar 2023 12:56:53 -0600
Subject: [PATCH 6/6] Corrects code formattting

---
 generative/inferers/inferer.py | 18 +++++++++---------
 tests/test_vqvaetransformer_inferer.py | 5 +++--
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 85f6a0ff..2f3e2424 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -458,12 +458,12 @@ def __call__(
         seq_len = latent.shape[1]
         max_seq_len = transformer_model.max_seq_len
         if max_seq_len < seq_len:
-            start = torch.randint(low=0, high=seq_len + 1 -max_seq_len, size=(1,)).item()
+            start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()
         else:
             start = 0
-        prediction = transformer_model(x=latent[:,start:start+max_seq_len], context=condition)
+        prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
 
         if return_latent:
-            return prediction, latent[:, start:start + max_seq_len], latent_spatial_dim
+            return prediction, latent[:, start : start + max_seq_len], latent_spatial_dim
         else:
             return prediction
@@ -578,19 +578,19 @@ def get_likelihood(
         latent = latent.long()
 
         # get the first batch, up to max_seq_length, efficiently
-        logits = transformer_model(x=latent[:,:transformer_model.max_seq_len], context=condition)
+        logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
         probs = F.softmax(logits, dim=-1)
-        probs = torch.gather(probs, 2, latent[:, :transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
+        probs = torch.gather(probs, 2, latent[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
 
         # if we have not covered the full sequence we continue with inefficient looping
         if probs.shape[1] < latent.shape[1]:
             if verbose and has_tqdm:
-                progress_bar = tqdm(range(transformer_model.max_seq_len,seq_len+1))
+                progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len + 1))
             else:
-                progress_bar = iter(range(transformer_model.max_seq_len,seq_len+1))
+                progress_bar = iter(range(transformer_model.max_seq_len, seq_len + 1))
 
             for i in progress_bar:
-                idx_cond = latent[:,i+1-transformer_model.max_seq_len:i+1]
+                idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
                 # forward the model to get the logits for the index in the sequence
                 logits = transformer_model(x=idx_cond, context=condition)
                 # pluck the logits at the final step
@@ -598,7 +598,7 @@ def get_likelihood(
                 # apply softmax to convert logits to (normalized) probabilities
                 p = F.softmax(logits, dim=-1)
                 # select correct values and append
-                p = torch.gather(p, 1, idx_cond[:,-1].unsqueeze(1))
+                p = torch.gather(p, 1, idx_cond[:, -1].unsqueeze(1))
                 probs = torch.cat((probs, p), dim=1)
 
         # remove starting token probability
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index d901d307..87766811 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -104,7 +104,7 @@ def test_prediction_shape_shorter_sequence(
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
         stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
-        stage_2_params_shorter['max_seq_len'] = max_seq_len
+        stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
 
@@ -234,7 +234,7 @@ def test_get_likelihood_shorter_sequence(
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
         stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
-        stage_2_params_shorter['max_seq_len'] = max_seq_len
+        stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
 
@@ -251,6 +251,7 @@ def test_get_likelihood_shorter_sequence(
             inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
         )
         self.assertEqual(likelihood.shape, latent_shape)
+
     @parameterized.expand(TEST_CASES)
     def test_get_likelihood_resampling(
         self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape