From cd7b56cd6411397391bcf66367cd8b9c94e9cb93 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Wed, 8 Feb 2023 08:09:41 +0000
Subject: [PATCH 1/8] [WIP] Add inferer

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py         |  49 ++++++++
 tests/test_vqvaetransformer_inferer.py | 165 +++++++++++++++++++++++++
 2 files changed, 214 insertions(+)
 create mode 100644 tests/test_vqvaetransformer_inferer.py

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index f65bdb20..71787d72 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -416,3 +416,52 @@ def get_likelihood(
             intermediates = [resizer(x) for x in intermediates]
             outputs = (outputs[0], intermediates)
         return outputs
+
+
+class VQVAETransformerInferer(Inferer):
+    """
+    """
+
+    def __init__(self) -> None:
+        Inferer.__init__(self)
+
+    def __call__(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        condition: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image from which the latent representation will be extracted and noise is added.
+            vqvae_model: first stage model.
+            transformer_model: transformer model.
+            condition: conditioning for network input.
+        """
+        with torch.no_grad():
+            latent = vqvae_model.encode_stage_2_inputs(inputs)
+
+        prediction = transformer_model(x=latent, context=condition)
+
+        return prediction
+
+    def sample(
+        self,
+        vqvae_model: Callable[..., torch.Tensor],
+        transformer_model: Callable[..., torch.Tensor],
+        conditioning: Optional[torch.Tensor] = None,
+        verbose: Optional[bool] = True,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Sampling function for the VQVAE + Transformer model.
+
+        Args:
+            vqvae_model: first stage model.
+            transformer_model: model to sample from.
+            conditioning: Conditioning for network input.
+            verbose: if true, prints the progress bar of the sampling process.
+        """
+        pass
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
new file mode 100644
index 00000000..ff98a11e
--- /dev/null
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -0,0 +1,165 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from generative.inferers import DiffusionInferer
+from generative.networks.nets import DiffusionModelUNet
+from generative.networks.schedulers import DDIMScheduler, DDPMScheduler
+
+TEST_CASES = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_channels": [8],
+            "norm_num_groups": 8,
+            "attention_levels": [True],
+            "num_res_blocks": 1,
+            "num_head_channels": 8,
+        },
+        (2, 1, 8, 8),
+    ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_channels": [8],
+            "norm_num_groups": 8,
+            "attention_levels": [True],
+            "num_res_blocks": 1,
+            "num_head_channels": 8,
+        },
+        (2, 1, 8, 8, 8),
+    ],
+]
+
+
+class TestDiffusionSamplingInferer(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_call(self, model_params, input_shape):
+
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        input = torch.randn(input_shape).to(device)
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+        sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
+        self.assertEqual(sample.shape, input_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_sample_intermediates(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+        )
+        self.assertEqual(len(intermediates), 10)
+
+    @parameterized.expand(TEST_CASES)
+    def test_ddpm_sampler(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+        )
+        self.assertEqual(len(intermediates), 10)
+
+    @parameterized.expand(TEST_CASES)
+    def test_ddim_sampler(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+        )
+        self.assertEqual(len(intermediates), 10)
+
+    @parameterized.expand(TEST_CASES)
+    def test_sampler_conditioned(self, model_params, input_shape):
+        model_params["with_conditioning"] = True
+        model_params["cross_attention_dim"] = 3
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+        )
+        self.assertEqual(len(intermediates), 10)
+
+    @parameterized.expand(TEST_CASES)
+    def test_get_likelihood(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        input = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        likelihood, intermediates = inferer.get_likelihood(
+            inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
+        )
+        self.assertEqual(intermediates[0].shape, input.shape)
+        self.assertEqual(likelihood.shape[0], input.shape[0])
+
+    def test_normal_cdf(self):
+        from scipy.stats import norm
+
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+
+        x = torch.linspace(-10, 10, 20)
+        cdf_approx = inferer._approx_standard_normal_cdf(x)
+        cdf_true = norm.cdf(x)
+        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 4f22bd86f8560a99224c3630757b53242a65bdb8 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 11 Feb 2023 10:34:30 +0000
Subject: [PATCH 2/8] [WIP] Add sample method

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py | 43 ++++++++++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 71787d72..16ee1086 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -17,6 +17,7 @@
 import torch.nn as nn
 from monai.inferers import Inferer
 from monai.utils import optional_import
+import torch.nn.functional as F
 
 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
 
@@ -420,6 +421,7 @@ def get_likelihood(
 
 class VQVAETransformerInferer(Inferer):
     """
+    Class to perform inference with a VQVAE + Transformer model.
     """
 
     def __init__(self) -> None:
@@ -438,7 +440,7 @@ def __call__(
         Args:
             inputs: input image from which the latent representation will be extracted and noise is added.
             vqvae_model: first stage model.
-            transformer_model: transformer model.
+            transformer_model: autoregressive transformer model.
             condition: conditioning for network input.
         """
         with torch.no_grad():
@@ -448,11 +450,14 @@ def __call__(
 
         return prediction
 
+    @torch.no_grad()
     def sample(
         self,
         vqvae_model: Callable[..., torch.Tensor],
         transformer_model: Callable[..., torch.Tensor],
         conditioning: Optional[torch.Tensor] = None,
+        temperature: float = 1.0,
+        top_k: int | None =None,
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -462,6 +467,40 @@ def sample(
             vqvae_model: first stage model.
             transformer_model: model to sample from.
             conditioning: Conditioning for network input.
+            temperature: temperature for sampling.
+            top_k: top k sampling.
             verbose: if true, prints the progress bar of the sampling process.
         """
-        pass
+        # TODO: define number of steps based on the size of the image
+        steps = 100
+        latent = []
+        if verbose and has_tqdm:
+            progress_bar = tqdm(steps)
+        else:
+            progress_bar = iter(steps)
+
+        # start_ids = encode(start)
+        # x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+        #
+
+        for _ in range(steps):
+            # if the sequence context is growing too long we must crop it at block_size
+            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            # forward the model to get the logits for the index in the sequence
+            logits = transformer_model(x=latent, context=conditioning)
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+            # either sample from the distribution or take the most likely element
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        image = vqvae_model.decode_stage_2_outputs(latent)
+        return image
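A standalone sketch of the temperature/top-k filtering step introduced above, with made-up tensor sizes (illustrative only, not part of the patch):

import torch
import torch.nn.functional as F

batch, vocab = 2, 17
logits = torch.randn(batch, vocab)  # logits for the last sequence position, as in the loop above

temperature, top_k = 1.0, 5
logits = logits / temperature
# keep only the top-k logits per row; everything else goes to -inf before the softmax
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
probs = F.softmax(logits, dim=-1)  # exactly k non-zero probabilities per row
idx_next = torch.multinomial(probs, num_samples=1)  # sampled token indices, shape (batch, 1)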
From 17704c54ccbcb12c17408ceb5251954bc33ca3b6 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 11 Feb 2023 11:29:47 +0000
Subject: [PATCH 3/8] Add ordering and complete Inferer methods

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py | 52 +++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 17 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 16ee1086..8cea2732 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -11,7 +11,7 @@
 
 import math
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union, Sequence
 
 import torch
 import torch.nn as nn
@@ -432,6 +432,8 @@ def __call__(
         inputs: torch.Tensor,
         vqvae_model: Callable[..., torch.Tensor],
         transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
+        starting_token: int,
         condition: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
@@ -441,10 +443,19 @@ def __call__(
         Args:
             inputs: input image from which the latent representation will be extracted and noise is added.
             vqvae_model: first stage model.
            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            starting_token: token to start the sequence to be inputted in the transformer model, the "Begin Of Sentence"
+                (BOS) token.
             condition: conditioning for network input.
         """
         with torch.no_grad():
-            latent = vqvae_model.encode_stage_2_inputs(inputs)
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+
+        latent = F.pad(latent, (1, 0), "constant", starting_token)
+        latent = latent.long()
 
         prediction = transformer_model(x=latent, context=condition)
@@ -453,8 +464,11 @@ def __call__(
     @torch.no_grad()
     def sample(
         self,
+        sampled_image_shape: Sequence[int],
+        starting_tokens: torch.Tensor,
         vqvae_model: Callable[..., torch.Tensor],
         transformer_model: Callable[..., torch.Tensor],
+        ordering: Callable[..., torch.Tensor],
         conditioning: Optional[torch.Tensor] = None,
         temperature: float = 1.0,
         top_k: int | None =None,
@@ -464,6 +478,8 @@ def sample(
         Sampling function for the VQVAE + Transformer model.
 
         Args:
+            sampled_image_shape: shape of the sampled image.
+            starting_tokens: starting tokens for the sampling.
             vqvae_model: first stage model.
             transformer_model: model to sample from.
             conditioning: Conditioning for network input.
@@ -471,23 +487,23 @@ def sample(
             temperature: temperature for sampling.
             top_k: top k sampling.
             verbose: if true, prints the progress bar of the sampling process.
         """
-        # TODO: define number of steps based on the size of the image
-        steps = 100
-        latent = []
+        seq_len = math.prod(sampled_image_shape)
+
         if verbose and has_tqdm:
-            progress_bar = tqdm(steps)
+            progress_bar = tqdm(seq_len)
         else:
-            progress_bar = iter(steps)
-
-        # start_ids = encode(start)
-        # x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
-        #
+            progress_bar = iter(seq_len)
 
-        for _ in range(steps):
+        latent_seq = starting_tokens
+        for _ in progress_bar:
             # if the sequence context is growing too long we must crop it at block_size
-            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            if latent_seq.size(1) <= transformer_model.max_seq_len:
+                idx_cond = latent_seq
+            else:
+                idx_cond = latent_seq[:, -transformer_model.max_seq_len:]
+
             # forward the model to get the logits for the index in the sequence
-            logits = transformer_model(x=latent, context=conditioning)
+            logits = transformer_model(x=idx_cond, context=conditioning)
             # pluck the logits at the final step and scale by desired temperature
             logits = logits[:, -1, :] / temperature
             # optionally crop the logits to only the top k options
@@ -500,7 +516,9 @@ def sample(
             # sample from the distribution
             idx_next = torch.multinomial(probs, num_samples=1)
             # append sampled index to the running sequence and continue
-            idx = torch.cat((idx, idx_next), dim=1)
+            latent_seq = torch.cat((latent_seq, idx_next), dim=1)
+
+        latent_seq = latent_seq[:, 1:]
+        latent = latent_seq.view(-1, ordering.get_revert_sequence_ordering())
 
-        image = vqvae_model.decode_stage_2_outputs(latent)
-        return image
+        return vqvae_model.decode_stage_2_outputs(latent)
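Patch 3 turns the quantised latent grid into a 1D token sequence: raster-flatten, prepend a BOS token, and invert the ordering after sampling. A toy illustration of that round trip, using a plain raster order in place of the Ordering class (illustrative only):

import torch
import torch.nn.functional as F

bos = 16  # the starting token; vqvae_model.num_embeddings in the patch
latent = torch.tensor([[[0, 1], [2, 3]]])  # (batch=1, 2, 2) grid of codebook indices

seq = latent.reshape(1, -1)                # raster scan: [[0, 1, 2, 3]]
seq = F.pad(seq, (1, 0), "constant", bos)  # prepend BOS:  [[16, 0, 1, 2, 3]]
# the transformer is then asked to predict each token from the tokens before it

# reverting after sampling: drop the BOS token and reshape back to the grid
grid = seq[:, 1:].reshape(1, 2, 2)
assert torch.equal(grid, latent)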
From 04d8f4a290f37f8d81c01b5342e8b7745e9e7686 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 11 Feb 2023 19:44:02 +0000
Subject: [PATCH 4/8] Update inferer

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index c2c8018b..82f4fc20 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -16,9 +16,9 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from monai.inferers import Inferer
 from monai.utils import optional_import
-import torch.nn.functional as F
 
 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
@@ -469,7 +469,7 @@ def sample(
         ordering: Callable[..., torch.Tensor],
         conditioning: Optional[torch.Tensor] = None,
         temperature: float = 1.0,
-        top_k: int | None =None,
+        top_k: int | None = None,
         verbose: Optional[bool] = True,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
@@ -498,7 +498,7 @@ def sample(
             if latent_seq.size(1) <= transformer_model.max_seq_len:
                 idx_cond = latent_seq
             else:
-                idx_cond = latent_seq[:, -transformer_model.max_seq_len:]
+                idx_cond = latent_seq[:, -transformer_model.max_seq_len :]
 
             # forward the model to get the logits for the index in the sequence
             logits = transformer_model(x=idx_cond, context=conditioning)
@@ -507,7 +507,7 @@ def sample(
             # optionally crop the logits to only the top k options
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, [-1]]] = -float('Inf')
+                logits[logits < v[:, [-1]]] = -float("Inf")
             # apply softmax to convert logits to (normalized) probabilities
             probs = F.softmax(logits, dim=-1)
             # either sample from the distribution or take the most likely element
""" with torch.no_grad(): - latent = vqvae_model.encode_stage_2_inputs(inputs) + latent = vqvae_model.index_quantize(inputs) + + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + latent = F.pad(latent, (1, 0), "constant", starting_token) + latent = latent.long() prediction = transformer_model(x=latent, context=condition) @@ -453,8 +464,11 @@ def __call__( @torch.no_grad() def sample( self, + sampled_image_shape: Sequence[int, int, int] | Sequence[int, int], + starting_tokens: torch.Tensor, vqvae_model: Callable[..., torch.Tensor], transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], conditioning: Optional[torch.Tensor] = None, temperature: float = 1.0, top_k: int | None =None, @@ -464,6 +478,8 @@ def sample( Sampling function for the VQVAE + Transformer model. Args: + sampled_image_shape: shape of the sampled image. + starting_tokens: starting tokens for the sampling. vqvae_model: first stage model. transformer_model: model to sample from. conditioning: Conditioning for network input. @@ -471,23 +487,23 @@ def sample( top_k: top k sampling. verbose: if true, prints the progression bar of the sampling process. """ - # TODO: define number of steps based on the size of the image - steps = 100 - latent = [] + seq_len = math.prod(sampled_image_shape) + if verbose and has_tqdm: - progress_bar = tqdm(steps) + progress_bar = tqdm(seq_len) else: - progress_bar = iter(steps) - - # start_ids = encode(start) - # x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) - # + progress_bar = iter(seq_len) - for _ in range(steps): + latent_seq = starting_tokens + for _ in progress_bar: # if the sequence context is growing too long we must crop it at block_size - # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len:] + # forward the model to get the logits for the index in the sequence - logits = transformer_model(x=latent, context=conditioning) + logits = transformer_model(x=idx_cond, context=conditioning) # pluck the logits at the final step and scale by desired temperature logits = logits[:, -1, :] / temperature # optionally crop the logits to only the top k options @@ -500,7 +516,9 @@ def sample( # sample from the distribution idx_next = torch.multinomial(probs, num_samples=1) # append sampled index to the running sequence and continue - idx = torch.cat((idx, idx_next), dim=1) + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent = latent_seq.view(-1, ordering.get_revert_sequence_ordering()) - image = vqvae_model.decode_stage_2_outputs(latent) - return image + return vqvae_model.decode_stage_2_outputs(latent) From 04d8f4a290f37f8d81c01b5342e8b7745e9e7686 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 19:44:02 +0000 Subject: [PATCH 4/8] Update inferer Signed-off-by: Walter Hugo Lopez Pinaya --- generative/inferers/inferer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index c2c8018b..82f4fc20 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -16,9 +16,9 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.inferers import Inferer from monai.utils import optional_import -import 
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index ff98a11e..b39daa8e 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -9,14 +9,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import unittest
 
 import torch
 from parameterized import parameterized
 
-from generative.inferers import DiffusionInferer
-from generative.networks.nets import DiffusionModelUNet
-from generative.networks.schedulers import DDIMScheduler, DDPMScheduler
+from generative.inferers import VQVAETransformerInferer
+from generative.networks.nets import VQVAE, DecoderOnlyTransformer
+from generative.utils.ordering import Ordering, OrderingType
 
 TEST_CASES = [
     [
@@ -23,143 +25,89 @@
         {
             "spatial_dims": 2,
             "in_channels": 1,
             "out_channels": 1,
-            "num_channels": [8],
-            "norm_num_groups": 8,
-            "attention_levels": [True],
-            "num_res_blocks": 1,
-            "num_head_channels": 8,
+            "num_levels": 2,
+            "downsample_parameters": ((2, 4, 1, 1),) * 2,
+            "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+            "num_res_layers": 1,
+            "num_channels": 8,
+            "num_res_channels": [8, 8],
+            "num_embeddings": 16,
+            "embedding_dim": 8,
+        },
+        {
+            "num_tokens": 16 + 1,
+            "max_seq_len": 4 + 1,
+            "attn_layers_dim": 4,
+            "attn_layers_depth": 2,
+            "attn_layers_heads": 1,
+            "with_cross_attention": False,
+        },
+        {
+            "ordering_type": OrderingType.RASTER_SCAN.value,
+            "spatial_dims": 2,
+            "dimensions": (2, 2, 2),
         },
         (2, 1, 8, 8),
+        (2, 5, 17),
     ],
     [
         {
             "spatial_dims": 3,
             "in_channels": 1,
             "out_channels": 1,
-            "num_channels": [8],
-            "norm_num_groups": 8,
-            "attention_levels": [True],
-            "num_res_blocks": 1,
-            "num_head_channels": 8,
+            "num_levels": 2,
+            "downsample_parameters": ((2, 4, 1, 1),) * 2,
+            "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+            "num_res_layers": 1,
+            "num_channels": 8,
+            "num_res_channels": [8, 8],
+            "num_embeddings": 16,
+            "embedding_dim": 8,
+        },
+        {
+            "num_tokens": 16 + 1,
+            "max_seq_len": 9 + 1,
+            "attn_layers_dim": 4,
+            "attn_layers_depth": 2,
+            "attn_layers_heads": 1,
+            "with_cross_attention": False,
+        },
+        {
+            "ordering_type": OrderingType.RASTER_SCAN.value,
+            "spatial_dims": 3,
+            "dimensions": (2, 2, 2, 2),
         },
         (2, 1, 8, 8, 8),
+        (2, 9, 17),
     ],
 ]
 
 
-class TestDiffusionSamplingInferer(unittest.TestCase):
+class TestVQVAETransformerInferer(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_call(self, model_params, input_shape):
-
-        model = DiffusionModelUNet(**model_params)
+    def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
+        stage_1 = VQVAE(**stage_1_params)
+        stage_2 = DecoderOnlyTransformer(**stage_2_params)
+        ordering = Ordering(**ordering_params)
+
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        input = torch.randn(input_shape).to(device)
-        noise = torch.randn(input_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
-        sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
-        self.assertEqual(sample.shape, input_shape)
-
-    @parameterized.expand(TEST_CASES)
-    def test_sample_intermediates(self, model_params, input_shape):
-        model = DiffusionModelUNet(**model_params)
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        noise = torch.randn(input_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
-        )
-        self.assertEqual(len(intermediates), 10)
-
-    @parameterized.expand(TEST_CASES)
-    def test_ddpm_sampler(self, model_params, input_shape):
-        model = DiffusionModelUNet(**model_params)
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        noise = torch.randn(input_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=1000)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
-        )
-        self.assertEqual(len(intermediates), 10)
-
-    @parameterized.expand(TEST_CASES)
-    def test_ddim_sampler(self, model_params, input_shape):
-        model = DiffusionModelUNet(**model_params)
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        noise = torch.randn(input_shape).to(device)
-        scheduler = DDIMScheduler(num_train_timesteps=1000)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
-        )
-        self.assertEqual(len(intermediates), 10)
-
-    @parameterized.expand(TEST_CASES)
-    def test_sampler_conditioned(self, model_params, input_shape):
-        model_params["with_conditioning"] = True
-        model_params["cross_attention_dim"] = 3
-        model = DiffusionModelUNet(**model_params)
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        noise = torch.randn(input_shape).to(device)
-        scheduler = DDIMScheduler(num_train_timesteps=1000)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
-        sample, intermediates = inferer.sample(
-            input_noise=noise,
-            diffusion_model=model,
-            scheduler=scheduler,
-            save_intermediates=True,
-            intermediate_steps=1,
-            conditioning=conditioning,
-        )
-        self.assertEqual(len(intermediates), 10)
-
-    @parameterized.expand(TEST_CASES)
-    def test_get_likelihood(self, model_params, input_shape):
-        model = DiffusionModelUNet(**model_params)
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        model.eval()
-        input = torch.randn(input_shape).to(device)
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = DiffusionInferer(scheduler=scheduler)
-        scheduler.set_timesteps(num_inference_steps=10)
-        likelihood, intermediates = inferer.get_likelihood(
-            inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
-        )
-        self.assertEqual(intermediates[0].shape, input.shape)
-        self.assertEqual(likelihood.shape[0], input.shape[0])
-
-    def test_normal_cdf(self):
-        from scipy.stats import norm
-
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        inferer = DiffusionInferer(scheduler=scheduler)
-
-        x = torch.linspace(-10, 10, 20)
-        cdf_approx = inferer._approx_standard_normal_cdf(x)
-        cdf_true = norm.cdf(x)
-        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+
+        inferer = VQVAETransformerInferer()
+        prediction = inferer(
+            inputs=input,
+            vqvae_model=stage_1,
+            transformer_model=stage_2,
+            ordering=ordering,
+            starting_token=16,  # from stage_1 num_embeddings
+        )
+        self.assertEqual(prediction.shape, latent_shape)
 
 
 if __name__ == "__main__":
     unittest.main()
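The expected latent_shape values in the new test follow from the VQVAE and transformer configuration; a short arithmetic check for the 2D case (numbers taken from the test parameters above):

import math

batch = 2
latent_spatial = (8 // (2 * 2), 8 // (2 * 2))  # two stride-2 downsamplings: 8x8 image -> 2x2 latent grid
seq_len = math.prod(latent_spatial) + 1        # 4 codebook tokens + 1 BOS token = 5 positions
vocab = 16 + 1                                 # num_embeddings + BOS = 17 logits per position

assert (batch, seq_len, vocab) == (2, 5, 17)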
From 4be6e611400f64a99f25e37fcdc5cc0620cfe6a3 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 11 Feb 2023 21:38:48 +0000
Subject: [PATCH 6/8] Add test_sample

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py         | 23 ++++++-----
 tests/test_vqvaetransformer_inferer.py | 56 +++++++++++++++++++++-----
 2 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index f808a6cd..e87d131d 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -443,7 +443,7 @@ def __call__(
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
             starting_token: token to start the sequence to be inputted in the transformer model, the "Begin Of Sentence"
-                (BOS) token.
+                (BOS) token. It must be the vqvae_model.num_embeddings value.
             condition: conditioning for network input.
         """
         with torch.no_grad():
@@ -462,7 +462,7 @@ def __call__(
     @torch.no_grad()
     def sample(
         self,
-        sampled_image_shape: Sequence[int],
+        latent_spatial_dim: Sequence[int],
         starting_tokens: torch.Tensor,
         vqvae_model: Callable[..., torch.Tensor],
         transformer_model: Callable[..., torch.Tensor],
@@ -476,8 +476,8 @@ def sample(
         Sampling function for the VQVAE + Transformer model.
 
         Args:
-            sampled_image_shape: shape of the sampled image.
-            starting_tokens: starting tokens for the sampling.
+            latent_spatial_dim: spatial shape of the latent representation to be sampled.
+            starting_tokens: starting tokens for the sampling. It must be the vqvae_model.num_embeddings value.
             vqvae_model: first stage model.
             transformer_model: model to sample from.
             conditioning: Conditioning for network input.
@@ -485,14 +485,14 @@ def sample(
             temperature: temperature for sampling.
             top_k: top k sampling.
             verbose: if true, prints the progress bar of the sampling process.
         """
-        seq_len = math.prod(sampled_image_shape)
+        seq_len = math.prod(latent_spatial_dim)
 
         if verbose and has_tqdm:
-            progress_bar = tqdm(seq_len)
+            progress_bar = tqdm(range(seq_len))
         else:
-            progress_bar = iter(seq_len)
+            progress_bar = iter(range(seq_len))
 
-        latent_seq = starting_tokens
+        latent_seq = starting_tokens.long()
         for _ in progress_bar:
             # if the sequence context is growing too long we must crop it at block_size
@@ -510,6 +510,8 @@ def sample(
                 logits[logits < v[:, [-1]]] = -float("Inf")
             # apply softmax to convert logits to (normalized) probabilities
             probs = F.softmax(logits, dim=-1)
+            # prevent the BOS token from being sampled
+            probs[:, vqvae_model.num_embeddings] = 0
             # either sample from the distribution or take the most likely element
             # sample from the distribution
             idx_next = torch.multinomial(probs, num_samples=1)
@@ -517,6 +519,7 @@ def sample(
             latent_seq = torch.cat((latent_seq, idx_next), dim=1)
 
         latent_seq = latent_seq[:, 1:]
-        latent = latent_seq.view(-1, ordering.get_revert_sequence_ordering())
+        latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
+        latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)
 
-        return vqvae_model.decode_stage_2_outputs(latent)
+        return vqvae_model.decode_samples(latent)
""" - seq_len = math.prod(sampled_image_shape) + seq_len = math.prod(latent_spatial_dim) if verbose and has_tqdm: - progress_bar = tqdm(seq_len) + progress_bar = tqdm(range(seq_len)) else: - progress_bar = iter(seq_len) + progress_bar = iter(range(seq_len)) - latent_seq = starting_tokens + latent_seq = starting_tokens.long() for _ in progress_bar: # if the sequence context is growing too long we must crop it at block_size if latent_seq.size(1) <= transformer_model.max_seq_len: @@ -510,6 +510,8 @@ def sample( logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 # either sample from the distribution or take the most likely element # sample from the distribution idx_next = torch.multinomial(probs, num_samples=1) @@ -517,6 +519,7 @@ def sample( latent_seq = torch.cat((latent_seq, idx_next), dim=1) latent_seq = latent_seq[:, 1:] - latent = latent_seq.view(-1, ordering.get_revert_sequence_ordering()) + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) - return vqvae_model.decode_stage_2_outputs(latent) + return vqvae_model.decode_samples(latent) diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py index b39daa8e..9c49a304 100644 --- a/tests/test_vqvaetransformer_inferer.py +++ b/tests/test_vqvaetransformer_inferer.py @@ -43,11 +43,7 @@ "attn_layers_heads": 1, "with_cross_attention": False, }, - { - "ordering_type": OrderingType.RASTER_SCAN.value, - "spatial_dims": 2, - "dimensions": (2, 2, 2), - }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, (2, 1, 8, 8), (2, 5, 17), ], @@ -73,16 +69,13 @@ "attn_layers_heads": 1, "with_cross_attention": False, }, - { - "ordering_type": OrderingType.RASTER_SCAN.value, - "spatial_dims": 3, - "dimensions": (2, 2, 2, 2), - }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, (2, 1, 8, 8, 8), (2, 9, 17), ], ] + class TestVQVAETransformerInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape): @@ -108,6 +101,49 @@ def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, ) self.assertEqual(prediction.shape, latent_shape) + def test_sample(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_levels=2, + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_channels=8, + num_res_channels=(8, 8), + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4 + 1, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + 
From 7fdc10cbe636c5801f89dd40f7340e8e822a3d83 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 13 Feb 2023 21:35:00 +0000
Subject: [PATCH 7/8] Fix comments and docstring

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index e87d131d..f3edc119 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -438,7 +438,7 @@ def __call__(
         Implements the forward pass for a supervised training iteration.
 
         Args:
-            inputs: input image from which the latent representation will be extracted and noise is added.
+            inputs: input image from which the latent representation will be extracted.
             vqvae_model: first stage model.
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
@@ -512,7 +512,6 @@ def sample(
             probs = F.softmax(logits, dim=-1)
             # prevent the BOS token from being sampled
             probs[:, vqvae_model.num_embeddings] = 0
-            # either sample from the distribution or take the most likely element
             # sample from the distribution
             idx_next = torch.multinomial(probs, num_samples=1)
             # append sampled index to the running sequence and continue
From e5c547edf10d4d5c84e32c9a9ca602f4af1e130d Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 13 Feb 2023 21:43:33 +0000
Subject: [PATCH 8/8] Remove starting_token from __call__

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/inferers/inferer.py         | 7 +++----
 tests/test_vqvaetransformer_inferer.py | 8 +-------
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index f3edc119..f55528a6 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -431,7 +431,6 @@ def __call__(
         vqvae_model: Callable[..., torch.Tensor],
         transformer_model: Callable[..., torch.Tensor],
         ordering: Callable[..., torch.Tensor],
-        starting_token: int,
         condition: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """
@@ -442,8 +441,6 @@ def __call__(
             vqvae_model: first stage model.
             transformer_model: autoregressive transformer model.
             ordering: ordering of the quantised latent representation.
-            starting_token: token to start the sequence to be inputted in the transformer model, the "Begin Of Sentence"
-                (BOS) token. It must be the vqvae_model.num_embeddings value.
             condition: conditioning for network input.
         """
         with torch.no_grad():
@@ -452,7 +449,9 @@ def __call__(
         latent = latent.reshape(latent.shape[0], -1)
         latent = latent[:, ordering.get_sequence_ordering()]
 
-        latent = F.pad(latent, (1, 0), "constant", starting_token)
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
         latent = latent.long()
 
         prediction = transformer_model(x=latent, context=condition)
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 9c49a304..edb152a3 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -92,13 +92,7 @@ def test_prediction_shape(
         input = torch.randn(input_shape).to(device)
 
         inferer = VQVAETransformerInferer()
-        prediction = inferer(
-            inputs=input,
-            vqvae_model=stage_1,
-            transformer_model=stage_2,
-            ordering=ordering,
-            starting_token=16,  # from stage_1 num_embeddings
-        )
+        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
         self.assertEqual(prediction.shape, latent_shape)
 
     def test_sample(self):
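Taken together, the series leaves the inferer with a train-time __call__ and a sample method. A minimal end-to-end usage sketch mirroring the configurations used in the tests above (illustrative only; shapes in the comments follow from the 2D test case):

import torch

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType

device = "cuda:0" if torch.cuda.is_available() else "cpu"
vqvae = VQVAE(
    spatial_dims=2, in_channels=1, out_channels=1, num_levels=2,
    downsample_parameters=((2, 4, 1, 1),) * 2, upsample_parameters=((2, 4, 1, 1, 0),) * 2,
    num_res_layers=1, num_channels=8, num_res_channels=(8, 8),
    num_embeddings=16, embedding_dim=8,
).to(device)
transformer = DecoderOnlyTransformer(
    num_tokens=16 + 1,  # codebook size + 1 for the BOS token, as required by the inferer
    max_seq_len=4 + 1, attn_layers_dim=4, attn_layers_depth=2,
    attn_layers_heads=1, with_cross_attention=False,
).to(device)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
inferer = VQVAETransformerInferer()

images = torch.randn(2, 1, 8, 8).to(device)

# training-time forward pass: logits over the BOS-prefixed token sequence, shape (2, 5, 17)
logits = inferer(inputs=images, vqvae_model=vqvae, transformer_model=transformer, ordering=ordering)

# sampling: every sequence starts from the BOS token (= vqvae.num_embeddings); output shape (2, 1, 8, 8)
samples = inferer.sample(
    latent_spatial_dim=(2, 2),
    starting_tokens=vqvae.num_embeddings * torch.ones((2, 1), device=device),
    vqvae_model=vqvae,
    transformer_model=transformer,
    ordering=ordering,
)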