diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 40c3459b..2f3e2424 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -454,9 +454,16 @@ def __call__(
         latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
         latent = latent.long()

-        prediction = transformer_model(x=latent, context=condition)
+        # train on a part of the sequence if it is longer than max_seq_length
+        seq_len = latent.shape[1]
+        max_seq_len = transformer_model.max_seq_len
+        if max_seq_len < seq_len:
+            start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()
+        else:
+            start = 0
+        prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
         if return_latent:
-            return prediction, latent, latent_spatial_dim
+            return prediction, latent[:, start : start + max_seq_len], latent_spatial_dim
         else:
             return prediction

@@ -534,6 +541,7 @@ def get_likelihood(
         condition: torch.Tensor | None = None,
         resample_latent_likelihoods: bool = False,
         resample_interpolation_mode: str = "nearest",
+        verbose: bool = False,
     ) -> torch.Tensor:
         """
         Computes the log-likelihoods of the latent representations of the input.
@@ -548,24 +556,52 @@ def get_likelihood(
                 dimension as the input images.
             resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', or 'trilinear;
+            verbose: if true, prints the progress bar of the likelihood computation.
+
         """
         if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
-        logits, latent, latent_spatial_dim = self(
-            inputs=inputs,
-            vqvae_model=vqvae_model,
-            transformer_model=transformer_model,
-            ordering=ordering,
-            condition=condition,
-            return_latent=True,
-        )
-        all_probs = F.softmax(logits, dim=-1)
-        probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
-        probs = probs.squeeze(2)
-        # remove the probability of the starting token
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent_spatial_dim = tuple(latent.shape[1:])
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+        seq_len = math.prod(latent_spatial_dim)
+
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        latent = latent.long()
+
+        # get the first batch, up to max_seq_length, efficiently
+        logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
+        probs = F.softmax(logits, dim=-1)
+        probs = torch.gather(probs, 2, latent[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
+
+        # if we have not covered the full sequence we continue with inefficient looping
+        if probs.shape[1] < latent.shape[1]:
+            if verbose and has_tqdm:
+                progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len + 1))
+            else:
+                progress_bar = iter(range(transformer_model.max_seq_len, seq_len + 1))
+
+            for i in progress_bar:
+                idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
+                # forward the model to get the logits for the index in the sequence
+                logits = transformer_model(x=idx_cond, context=condition)
+                # pluck the logits at the final step
+                logits = logits[:, -1, :]
+                # apply softmax to convert logits to (normalized) probabilities
+                p = F.softmax(logits, dim=-1)
+                # select correct values and append
+                p = torch.gather(p, 1, idx_cond[:, -1].unsqueeze(1))
+                probs = torch.cat((probs, p), dim=1)
+
+        # remove starting token probability
         probs = probs[:, 1:]
         probs = torch.log(probs)
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 972885f6..87766811 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -97,6 +97,30 @@ def test_prediction_shape(
         prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
         self.assertEqual(prediction.shape, logits_shape)

+    @parameterized.expand(TEST_CASES)
+    def test_prediction_shape_shorter_sequence(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        max_seq_len = 3
+        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter["max_seq_len"] = max_seq_len
+        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+        cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2])
+        self.assertEqual(prediction.shape, cropped_logits_shape)
+
     def test_sample(self):
         stage_1 = VQVAE(
             spatial_dims=2,
@@ -139,6 +163,48 @@ def test_sample(self):
         )
         self.assertEqual(sample.shape, (2, 1, 8, 8))

+    def test_sample_shorter_sequence(self):
+        stage_1 = VQVAE(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_channels=(8, 8),
+            num_res_channels=(8, 8),
+            downsample_parameters=((2, 4, 1, 1),) * 2,
+            upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+            num_res_layers=1,
+            num_embeddings=16,
+            embedding_dim=8,
+        )
+        stage_2 = DecoderOnlyTransformer(
+            num_tokens=16 + 1,
+            max_seq_len=3,
+            attn_layers_dim=4,
+            attn_layers_depth=2,
+            attn_layers_heads=1,
+            with_cross_attention=False,
+        )
+        ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        inferer = VQVAETransformerInferer()
+
+        starting_token = 16  # from stage_1 num_embeddings
+
+        sample = inferer.sample(
+            latent_spatial_dim=(2, 2),
+            starting_tokens=starting_token * torch.ones((2, 1), device=device),
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+        )
+        self.assertEqual(sample.shape, (2, 1, 8, 8))
+
     @parameterized.expand(TEST_CASES)
     def test_get_likelihood(
         self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
@@ -161,6 +227,31 @@ def test_get_likelihood(
         )
         self.assertEqual(likelihood.shape, latent_shape)

+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood_shorter_sequence(
+        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+    ):
+        stage_1 = VQVAE(**stage_1_params)
+        max_seq_len = 3
+        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter["max_seq_len"] = max_seq_len
+        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        likelihood = inferer.get_likelihood(
+            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+        )
+        self.assertEqual(likelihood.shape, latent_shape)
+
     @parameterized.expand(TEST_CASES)
     def test_get_likelihood_resampling(
         self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
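
Usage sketch (not part of the diff): the snippet below shows how the updated get_likelihood could be called with verbose=True on a transformer whose max_seq_len is shorter than the flattened latent sequence, so the sliding-window loop added above is exercised. The network hyper-parameters mirror the test fixtures above; the import paths are assumptions based on the repository layout.

    import torch

    from generative.inferers import VQVAETransformerInferer
    from generative.networks.nets import VQVAE, DecoderOnlyTransformer
    from generative.utils.enums import OrderingType
    from generative.utils.ordering import Ordering

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # First stage: VQ-VAE with 16 codebook entries, downsampling an 8x8 image to a 2x2 latent grid.
    vqvae = VQVAE(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 8),
        num_res_channels=(8, 8),
        downsample_parameters=((2, 4, 1, 1),) * 2,
        upsample_parameters=((2, 4, 1, 1, 0),) * 2,
        num_res_layers=1,
        num_embeddings=16,
        embedding_dim=8,
    ).to(device)
    vqvae.eval()

    # Second stage: transformer whose max_seq_len (3) is shorter than the flattened latent
    # sequence (4 tokens + BOS), so get_likelihood falls back to the sliding-window loop.
    transformer = DecoderOnlyTransformer(
        num_tokens=16 + 1,  # codebook size + 1 for the BOS token
        max_seq_len=3,
        attn_layers_dim=4,
        attn_layers_depth=2,
        attn_layers_heads=1,
        with_cross_attention=False,
    ).to(device)
    transformer.eval()

    # Raster-scan ordering over the 2x2 latent grid, as in the tests above.
    ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

    inferer = VQVAETransformerInferer()
    images = torch.randn(2, 1, 8, 8, device=device)

    with torch.no_grad():
        # verbose=True shows the tqdm progress bar over the looped positions (if tqdm is installed).
        likelihood = inferer.get_likelihood(
            inputs=images,
            vqvae_model=vqvae,
            transformer_model=transformer,
            ordering=ordering,
            verbose=True,
        )

    print(likelihood.shape)  # one log-likelihood per latent position, e.g. torch.Size([2, 2, 2])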