diff --git a/generative/inferers/__init__.py b/generative/inferers/__init__.py
index 94775e76..e6402093 100644
--- a/generative/inferers/__init__.py
+++ b/generative/inferers/__init__.py
@@ -11,4 +11,4 @@
 
 from __future__ import annotations
 
-from .inferer import DiffusionInferer, LatentDiffusionInferer
+from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer
diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 3b4f6547..f55528a6 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -12,10 +12,11 @@
 from __future__ import annotations
 
 import math
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from monai.inferers import Inferer
 from monai.utils import optional_import
 
@@ -414,3 +415,109 @@ def get_likelihood(
         intermediates = [resizer(x) for x in intermediates]
         outputs = (outputs[0], intermediates)
         return outputs
+
+
+class VQVAETransformerInferer(Inferer):
+    """
+    Class to perform inference with a VQVAE + Transformer model.
+    """
+
+    def __init__(self) -> None:
+        Inferer.__init__(self)
+
+    def __call__(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
+        condition: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image from which the latent representation will be extracted.
+            vqvae_model: first stage model.
+            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            condition: conditioning for network input.
+        """
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+
+        # Use the value vqvae_model.num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note that transformer_model must be defined with num_tokens = vqvae_model.num_embeddings + 1.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        latent = latent.long()
+
+        prediction = transformer_model(x=latent, context=condition)
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(
+        self,
+        latent_spatial_dim: Sequence[int],
+        starting_tokens: torch.Tensor,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
+        conditioning: torch.Tensor | None = None,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        verbose: bool | None = True,
+    ) -> torch.Tensor:
+        """
+        Sampling function for the VQVAE + Transformer model.
+
+        Args:
+            latent_spatial_dim: spatial shape of the quantised latent representation to be sampled.
+            starting_tokens: starting tokens for the sampling, filled with the BOS value vqvae_model.num_embeddings.
+            vqvae_model: first stage model.
+            transformer_model: model to sample from.
+            ordering: ordering of the quantised latent representation.
+            conditioning: conditioning for network input.
+            temperature: temperature for sampling.
+            top_k: top-k sampling.
+            verbose: if true, prints the progress bar of the sampling process.
+ """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 00000000..edb152a3 --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,143 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from generative.inferers import VQVAETransformerInferer
+from generative.networks.nets import VQVAE, DecoderOnlyTransformer
+from generative.utils.ordering import Ordering, OrderingType
+
+TEST_CASES = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_levels": 2,
+            "downsample_parameters": ((2, 4, 1, 1),) * 2,
+            "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+            "num_res_layers": 1,
+            "num_channels": 8,
+            "num_res_channels": [8, 8],
+            "num_embeddings": 16,
+            "embedding_dim": 8,
+        },
+        {
+            "num_tokens": 16 + 1,
+            "max_seq_len": 4 + 1,
+            "attn_layers_dim": 4,
+            "attn_layers_depth": 2,
+            "attn_layers_heads": 1,
+            "with_cross_attention": False,
+        },
+        {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
+        (2, 1, 8, 8),
+        (2, 5, 17),
+    ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_levels": 2,
+            "downsample_parameters": ((2, 4, 1, 1),) * 2,
+            "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+            "num_res_layers": 1,
+            "num_channels": 8,
+            "num_res_channels": [8, 8],
+            "num_embeddings": 16,
+            "embedding_dim": 8,
+        },
+        {
+            "num_tokens": 16 + 1,
+            "max_seq_len": 9 + 1,
+            "attn_layers_dim": 4,
+            "attn_layers_depth": 2,
+            "attn_layers_heads": 1,
+            "with_cross_attention": False,
+        },
+        {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
+        (2, 1, 8, 8, 8),
+        (2, 9, 17),
+    ],
+]
+
+
+class TestVQVAETransformerInferer(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+        self.assertEqual(prediction.shape, latent_shape)
+
+    def test_sample(self):
+        stage_1 = VQVAE(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_levels=2,
+            downsample_parameters=((2, 4, 1, 1),) * 2,
+            upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+            num_res_layers=1,
+            num_channels=8,
+            num_res_channels=(8, 8),
+            num_embeddings=16,
+            embedding_dim=8,
+        )
+        stage_2 = DecoderOnlyTransformer(
+            num_tokens=16 + 1,
+            max_seq_len=4 + 1,
+            attn_layers_dim=4,
+            attn_layers_depth=2,
+            attn_layers_heads=1,
+            with_cross_attention=False,
+        )
+        ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        inferer = VQVAETransformerInferer()
+
+        starting_token = 16  # from stage_1 num_embeddings
+
+        sample = inferer.sample(
+            latent_spatial_dim=(2, 2),
+            starting_tokens=starting_token * torch.ones((2, 1), device=device),
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+        )
+        self.assertEqual(sample.shape, (2, 1, 8, 8))
+
+
+if __name__ == "__main__":
+    unittest.main()
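
For reviewers, a minimal end-to-end sketch of how the new inferer is intended to be wired up, reusing the toy 2D configuration from the tests above. Only VQVAETransformerInferer, VQVAE, DecoderOnlyTransformer and Ordering come from the repository; the random data tensor and the cross-entropy target construction are illustrative assumptions, not part of this patch.

# Sketch only: a training-step forward pass and a sampling call for the VQVAE + Transformer pipeline.
import torch
import torch.nn.functional as F

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# first stage: maps 8x8 images to a 2x2 grid of indices over a codebook of 16 embeddings
vqvae = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_levels=2,
    downsample_parameters=((2, 4, 1, 1),) * 2,
    upsample_parameters=((2, 4, 1, 1, 0),) * 2,
    num_res_layers=1,
    num_channels=8,
    num_res_channels=(8, 8),
    num_embeddings=16,
    embedding_dim=8,
).to(device)

# second stage: num_tokens = num_embeddings + 1 to make room for the BOS token,
# max_seq_len = number of latent positions + 1
transformer = DecoderOnlyTransformer(
    num_tokens=16 + 1,
    max_seq_len=4 + 1,
    attn_layers_dim=4,
    attn_layers_depth=2,
    attn_layers_heads=1,
    with_cross_attention=False,
).to(device)

ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
inferer = VQVAETransformerInferer()

# training step: logits have shape (batch, seq_len, num_tokens); position i predicts the token
# that follows the BOS-prefixed input at position i, so the last position is dropped here
# (assumed target construction for illustration).
images = torch.randn(2, 1, 8, 8, device=device)
logits = inferer(inputs=images, vqvae_model=vqvae, transformer_model=transformer, ordering=ordering)

with torch.no_grad():
    target = vqvae.index_quantize(images).reshape(images.shape[0], -1)
    target = target[:, ordering.get_sequence_ordering()].long()
loss = F.cross_entropy(logits[:, :-1, :].transpose(1, 2), target)

# sampling: starting tokens are filled with the BOS value (num_embeddings)
samples = inferer.sample(
    latent_spatial_dim=(2, 2),
    starting_tokens=16 * torch.ones((2, 1), device=device),
    vqvae_model=vqvae,
    transformer_model=transformer,
    ordering=ordering,
)
print(samples.shape)  # torch.Size([2, 1, 8, 8])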