From 0bef5d07271e0056473a3e97a45285778b4b1f92 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 14:43:38 +0000 Subject: [PATCH 01/15] [WIP] Remove num_levels Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 52 +++++++++---------- tests/test_vqvae.py | 85 ++++++++++++++++--------------- 2 files changed, 70 insertions(+), 67 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index feceed13..471a1666 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -103,7 +103,6 @@ class VQVAE(nn.Module): spatial_dims: number of spatial spatial_dims. in_channels: number of input channels. out_channels: number of output channels. - num_levels: number of levels that the network has. downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the following information stride (int), kernel_size (int), dilation (int) and padding (int). upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the @@ -131,16 +130,15 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - num_levels: int = 3, + num_channels: Sequence[int] | int = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), downsample_parameters: tuple[tuple[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters: tuple[tuple[int, int, int, int, int], ...] = ( (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), ), - num_res_layers: int = 3, - num_channels: Sequence[int] | int = (96, 96, 192), - num_res_channels: Sequence[int] | int = (96, 96, 192), num_embeddings: int = 32, embedding_dim: int = 64, embedding_init: str = "normal", @@ -158,28 +156,30 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.spatial_dims = spatial_dims + self.num_channels = num_channels - if isinstance(num_channels, int): - num_channels = ensure_tuple_rep(num_channels, num_levels) if isinstance(num_res_channels, int): - num_res_channels = ensure_tuple_rep(num_res_channels, num_levels) - - assert ( - num_levels == len(downsample_parameters) - and num_levels == len(upsample_parameters) - and num_levels == len(num_channels) - and num_levels == len(num_res_channels) - ), ( - f"downsample_parameters, upsample_parameters, num_channels and num_res_channels must have the same number of" - f" elements as num_levels. But got {len(downsample_parameters)}, {len(upsample_parameters)}, " - f"{len(num_channels)} and {len(num_res_channels)} instead of {num_levels}." - ) + num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels)) + + if len(num_res_channels) != len(num_channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if len(downsample_parameters) != len(num_channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters) != len(num_channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." 
+ ) - self.num_levels = num_levels self.downsample_parameters = downsample_parameters self.upsample_parameters = upsample_parameters self.num_res_layers = num_res_layers - self.num_channels = num_channels self.num_res_channels = num_res_channels self.dropout = dropout @@ -204,7 +204,7 @@ def __init__( def construct_encoder(self) -> torch.nn.Sequential: encoder = [] - for idx in range(self.num_levels): + for idx in range(len(self.num_channels)): encoder.append( Convolution( spatial_dims=self.spatial_dims, @@ -303,7 +303,7 @@ def construct_decoder(self) -> torch.nn.Sequential: ) ] - for idx in range(self.num_levels): + for idx in range(len(self.num_channels)): for _ in range(self.num_res_layers): decoder.append( VQVAEResidualUnit( @@ -322,18 +322,18 @@ def construct_decoder(self) -> torch.nn.Sequential: Convolution( spatial_dims=self.spatial_dims, in_channels=decoder_num_channels[idx], - out_channels=self.out_channels if idx == self.num_levels - 1 else decoder_num_channels[idx + 1], + out_channels=self.out_channels if idx == len(self.num_channels) - 1 else decoder_num_channels[idx + 1], strides=self.upsample_parameters[idx][0], kernel_size=self.upsample_parameters[idx][1], adn_ordering=self.adn_ordering, act=self.act, - dropout=self.dropout if idx != self.num_levels - 1 else None, + dropout=self.dropout if idx != len(self.num_channels) - 1 else None, norm=None, dropout_dim=1, dilation=self.upsample_parameters[idx][2], groups=1, bias=True, - conv_only=idx == self.num_levels - 1, + conv_only=idx == len(self.num_channels) - 1, is_transposed=True, padding=self.upsample_parameters[idx][3], output_padding=self.upsample_parameters[idx][4], diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 8c0226ea..10eed726 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -26,12 +26,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (8, 8), "num_res_layers": 1, - "num_channels": 8, "num_res_channels": [8, 8], + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_embeddings": 16, "embedding_dim": 8, "embedding_init": "normal", @@ -51,12 +50,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (8, 8), "num_res_layers": 1, - "num_channels": 8, "num_res_channels": 8, + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_embeddings": 16, "embedding_dim": 8, "embedding_init": "normal", @@ -76,12 +74,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, - "num_res_layers": 1, "num_channels": [8, 8], + "num_res_layers": 1, "num_res_channels": [8, 8], + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_embeddings": 16, "embedding_dim": 8, "embedding_init": "normal", @@ -101,7 +98,6 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_levels": 2, "downsample_parameters": [(2, 4, 1, 1)] * 2, "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_res_layers": 1, @@ -123,34 +119,10 @@ ], ] -# 1-channel 2D, should fail because of number of levels, number of downsamplings, number of upsamplings mismatch. 
-TEST_CASE_FAIL = { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_levels": 3, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 4, - "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], - "num_embeddings": 16, - "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, -} - TEST_LATENT_SHAPE = { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, "downsample_parameters": [(2, 4, 1, 1)] * 2, "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_res_layers": 1, @@ -186,7 +158,6 @@ def test_script(self): spatial_dims=2, in_channels=1, out_channels=1, - num_levels=2, downsample_parameters=tuple([(2, 4, 1, 1)] * 2), upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 2), num_res_layers=1, @@ -207,9 +178,41 @@ def test_script(self): test_data = torch.randn(1, 1, 16, 16) test_script_save(net, test_data) - def test_level_upsample_downsample_difference(self): - with self.assertRaises(AssertionError): - VQVAE(**TEST_CASE_FAIL) + def test_num_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters = ((2, 4, 1, 1),) * 2, + upsample_parameters = ((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters = ((2, 4, 1, 1),) * 3, + upsample_parameters = ((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters = ((2, 4, 1, 1),) * 2, + upsample_parameters = ((2, 4, 1, 1, 0),) * 3, + ) def test_encode_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" From 3bca01045c9a7d5a5c2bf65320887f2b1ca9d81c Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 14:50:19 +0000 Subject: [PATCH 02/15] Remove num_levels Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 4 +++- tests/test_latent_diffusion_inferer.py | 7 +++---- tests/test_vqvae.py | 12 ++++++------ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 471a1666..41c16bfc 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -322,7 +322,9 @@ def construct_decoder(self) -> torch.nn.Sequential: Convolution( spatial_dims=self.spatial_dims, in_channels=decoder_num_channels[idx], - out_channels=self.out_channels if idx == len(self.num_channels) - 1 else decoder_num_channels[idx + 1], + out_channels=self.out_channels + if idx == len(self.num_channels) - 1 + else decoder_num_channels[idx + 1], strides=self.upsample_parameters[idx][0], kernel_size=self.upsample_parameters[idx][1], adn_ordering=self.adn_ordering, diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 2d4dbb20..9623c76a 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -54,12 +54,11 @@ 
"spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), - "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - "num_res_layers": 1, "num_channels": [4, 4], + "num_res_layers": 1, "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), "num_embeddings": 16, "embedding_dim": 3, }, diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 10eed726..264b743b 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -186,8 +186,8 @@ def test_num_channels_not_same_size_of_num_res_channels(self): out_channels=1, num_channels=(16, 16), num_res_channels=(16, 16, 16), - downsample_parameters = ((2, 4, 1, 1),) * 2, - upsample_parameters = ((2, 4, 1, 1, 0),) * 2, + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, ) def test_num_channels_not_same_size_of_downsample_parameters(self): @@ -198,8 +198,8 @@ def test_num_channels_not_same_size_of_downsample_parameters(self): out_channels=1, num_channels=(16, 16), num_res_channels=(16, 16), - downsample_parameters = ((2, 4, 1, 1),) * 3, - upsample_parameters = ((2, 4, 1, 1, 0),) * 2, + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, ) def test_num_channels_not_same_size_of_upsample_parameters(self): @@ -210,8 +210,8 @@ def test_num_channels_not_same_size_of_upsample_parameters(self): out_channels=1, num_channels=(16, 16), num_res_channels=(16, 16), - downsample_parameters = ((2, 4, 1, 1),) * 2, - upsample_parameters = ((2, 4, 1, 1, 0),) * 3, + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, ) def test_encode_shape(self): From 079758d861a6936092ac50023ea85da890b0098d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 15:54:14 +0000 Subject: [PATCH 03/15] Add Encoder and Decoder classes Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 364 +++++++++++++++++------------- 1 file changed, 202 insertions(+), 162 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 41c16bfc..015b1375 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -92,6 +92,171 @@ def forward(self, x): return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) +class Encoder(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Sequence[int, int, int, int], ...], + adn_ordering: str, + dropout: tuple | str | float | None, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.adn_ordering = adn_ordering + self.dropout = dropout + self.act = act + + blocks = [] + + for i in range(len(self.num_channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.num_channels[i - 1], + out_channels=self.num_channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering=self.adn_ordering, + 
act=self.act, + norm=None, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + num_channels=self.num_channels[i], + num_res_channels=self.num_res_channels[i], + adn_ordering=self.adn_ordering, + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_channels[len(self.num_channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Sequence[int, int, int, int], ...], + adn_ordering: str, + dropout: tuple | str | float | None, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.adn_ordering = adn_ordering + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.num_channels)) + + blocks = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.num_channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + num_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + adn_ordering=self.adn_ordering, + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.num_channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering=self.adn_ordering, + act=self.act, + dropout=self.dropout if i != len(self.num_channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.num_channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + class VQVAE(nn.Module): """ Single bottleneck implementation of Vector-Quantised Variational Autoencoder (VQ-VAE) as originally used in @@ -133,8 +298,8 @@ def __init__( num_channels: Sequence[int] | int = (96, 96, 192), num_res_layers: int = 3, num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: tuple[tuple[int, int, int, int], ...] 
= ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: tuple[tuple[int, int, int, int, int], ...] = ( + downsample_parameters: Sequence[Sequence[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Sequence[int, int, int, int, int], ...] = ( (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), @@ -157,6 +322,8 @@ def __init__( self.out_channels = out_channels self.spatial_dims = spatial_dims self.num_channels = num_channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim if isinstance(num_res_channels, int): num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels)) @@ -177,176 +344,49 @@ def __init__( "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." ) - self.downsample_parameters = downsample_parameters - self.upsample_parameters = upsample_parameters self.num_res_layers = num_res_layers self.num_res_channels = num_res_channels - self.dropout = dropout - self.act = act - self.adn_ordering = adn_ordering - - self.output_act = output_act - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.embedding_init = embedding_init - self.commitment_cost = commitment_cost - self.decay = decay - self.epsilon = epsilon - - self.ddp_sync = ddp_sync - - self.encoder = self.construct_encoder() - self.quantizer = self.construct_quantizer() - self.decoder = self.construct_decoder() - - def construct_encoder(self) -> torch.nn.Sequential: - encoder = [] - - for idx in range(len(self.num_channels)): - encoder.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.in_channels if idx == 0 else self.num_channels[idx - 1], - out_channels=self.num_channels[idx], - strides=self.downsample_parameters[idx][0], - kernel_size=self.downsample_parameters[idx][1], - adn_ordering=self.adn_ordering, - act=self.act, - norm=None, - dropout=None if idx == 0 else self.dropout, - dropout_dim=1, - dilation=self.downsample_parameters[idx][2], - groups=1, - bias=True, - conv_only=False, - is_transposed=False, - padding=self.downsample_parameters[idx][3], - output_padding=None, - ) - ) - - for _ in range(self.num_res_layers): - encoder.append( - VQVAEResidualUnit( - spatial_dims=self.spatial_dims, - num_channels=self.num_channels[idx], - num_res_channels=self.num_res_channels[idx], - adn_ordering=self.adn_ordering, - act=self.act, - dropout=self.dropout, - dropout_dim=1, - bias=True, - ) - ) - - encoder.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.num_channels[len(self.num_channels) - 1], - out_channels=self.embedding_dim, - strides=1, - kernel_size=3, - adn_ordering=self.adn_ordering, - act=None, - norm=None, - dropout=None, - dropout_dim=1, - dilation=1, - bias=True, - conv_only=True, - is_transposed=False, - padding=1, - output_padding=None, - ) + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + num_channels=num_channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters, + adn_ordering=adn_ordering, + dropout=dropout, + act=act, ) - return torch.nn.Sequential(*encoder) + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + num_channels=num_channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters, + adn_ordering=adn_ordering, + dropout=dropout, + act=act, 
+ output_act=output_act, + ) - # TODO: Include lucidrains' vector quantizer as an option - def construct_quantizer(self) -> torch.nn.Module: - return VectorQuantizer( + self.quantizer = VectorQuantizer( quantizer=EMAQuantizer( - spatial_dims=self.spatial_dims, - num_embeddings=self.num_embeddings, - embedding_dim=self.embedding_dim, - commitment_cost=self.commitment_cost, - decay=self.decay, - epsilon=self.epsilon, - embedding_init=self.embedding_init, - ddp_sync=self.ddp_sync, + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, ) ) - def construct_decoder(self) -> torch.nn.Sequential: - decoder_num_channels = list(reversed(self.num_channels)) - decoder_num_res_channels = list(reversed(self.num_res_channels)) - - decoder = [ - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.embedding_dim, - out_channels=decoder_num_channels[0], - strides=1, - kernel_size=3, - adn_ordering=self.adn_ordering, - act=None, - dropout=None, - norm=None, - dropout_dim=1, - dilation=1, - bias=True, - is_transposed=False, - padding=1, - output_padding=None, - ) - ] - - for idx in range(len(self.num_channels)): - for _ in range(self.num_res_layers): - decoder.append( - VQVAEResidualUnit( - spatial_dims=self.spatial_dims, - num_channels=decoder_num_channels[idx], - num_res_channels=decoder_num_res_channels[idx], - adn_ordering=self.adn_ordering, - act=self.act, - dropout=self.dropout, - dropout_dim=1, - bias=True, - ) - ) - - decoder.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=decoder_num_channels[idx], - out_channels=self.out_channels - if idx == len(self.num_channels) - 1 - else decoder_num_channels[idx + 1], - strides=self.upsample_parameters[idx][0], - kernel_size=self.upsample_parameters[idx][1], - adn_ordering=self.adn_ordering, - act=self.act, - dropout=self.dropout if idx != len(self.num_channels) - 1 else None, - norm=None, - dropout_dim=1, - dilation=self.upsample_parameters[idx][2], - groups=1, - bias=True, - conv_only=idx == len(self.num_channels) - 1, - is_transposed=True, - padding=self.upsample_parameters[idx][3], - output_padding=self.upsample_parameters[idx][4], - ) - ) - - if self.output_act: - decoder.append(Act[self.output_act]()) - - return torch.nn.Sequential(*decoder) - def encode(self, images: torch.Tensor) -> torch.Tensor: return self.encoder(images) From 507ae2bf318076c9e66923c8ca439c6611134593 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 16:00:15 +0000 Subject: [PATCH 04/15] Remove unused dropout_dim Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 015b1375..cb7e07ea 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -38,7 +38,6 @@ class VQVAEResidualUnit(nn.Module): adn_ordering : a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA". act: activation type and arguments. Defaults to RELU. dropout: dropout ratio. Defaults to no dropout. - dropout_dim: dimension along which to apply dropout. Defaults to 1. bias: whether to have a bias term. Defaults to True. 
""" @@ -50,7 +49,6 @@ def __init__( adn_ordering: str = "NDA", act: tuple | str | None = "RELU", dropout: tuple | str | float | None = None, - dropout_dim: int | None = 1, bias: bool = True, ) -> None: super().__init__() @@ -61,7 +59,6 @@ def __init__( self.adn_ordering = adn_ordering self.act = act self.dropout = dropout - self.dropout_dim = dropout_dim self.bias = bias self.conv1 = Convolution( @@ -72,7 +69,6 @@ def __init__( act=self.act, norm=None, dropout=self.dropout, - dropout_dim=self.dropout_dim, bias=self.bias, ) @@ -80,12 +76,8 @@ def __init__( spatial_dims=self.spatial_dims, in_channels=self.num_res_channels, out_channels=self.num_channels, - adn_ordering=self.adn_ordering, - act=None, - norm=None, - dropout=None, - dropout_dim=self.dropout_dim, bias=self.bias, + conv_only=True, ) def forward(self, x): From 3cdc5cffdc1bb2f813dd42f2da761bda546448fe Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 17:03:54 +0000 Subject: [PATCH 05/15] Add more checks for the parameters and tests Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 30 ++++-- tests/test_vqvae.py | 158 +++++++++++++++--------------- 2 files changed, 103 insertions(+), 85 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index cb7e07ea..34fba80f 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -290,12 +290,10 @@ def __init__( num_channels: Sequence[int] | int = (96, 96, 192), num_res_layers: int = 3, num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: Sequence[Sequence[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Sequence[Sequence[int, int, int, int, int], ...] = ( - (2, 4, 1, 1, 0), - (2, 4, 1, 1, 0), - (2, 4, 1, 1, 0), - ), + downsample_parameters: Sequence[Sequence[int, int, int, int], ...] + | Sequence[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Sequence[int, int, int, int, int], ...] + | Sequence[int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), num_embeddings: int = 32, embedding_dim: int = 64, embedding_init: str = "normal", @@ -326,6 +324,26 @@ def __init__( "`num_channels`." ) + if not all(isinstance(values, (int, Sequence)) for values in downsample_parameters): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + if not all(isinstance(values, (int, Sequence)) for values in upsample_parameters): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters = (upsample_parameters,) * len(num_channels) + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters = (downsample_parameters,) * len(num_channels) + + for parameter in downsample_parameters: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + if len(downsample_parameters) != len(num_channels): raise ValueError( "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." 
diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 264b743b..bb5486ee 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -26,96 +26,64 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (8, 8), + "num_channels": (4, 4), "num_res_layers": 1, - "num_res_channels": [8, 8], - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, - "num_embeddings": 16, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (8, 8), + "num_channels": (4, 4), "num_res_layers": 1, - "num_res_channels": 8, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, - "num_embeddings": 16, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": [8, 8], + "num_channels": (4, 4), "num_res_layers": 1, - "num_res_channels": [8, 8], - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, - "num_embeddings": 16, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (4, 4), "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], - "num_embeddings": 16, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16, 16), - (1, 1, 16, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], ] @@ -123,21 +91,13 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, "num_channels": [8, 8], "num_res_channels": [8, 8], "num_embeddings": 16, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, } @@ -165,14 +125,6 @@ def 
test_script(self): num_res_channels=[8, 8], num_embeddings=16, embedding_dim=8, - embedding_init="normal", - commitment_cost=0.25, - decay=0.5, - epsilon=1e-5, - adn_ordering="NDA", - dropout=0.1, - act="RELU", - output_act=None, ddp_sync=False, ) test_data = torch.randn(1, 1, 16, 16) @@ -214,6 +166,54 @@ def test_num_channels_not_same_size_of_upsample_parameters(self): upsample_parameters=((2, 4, 1, 1, 0),) * 3, ) + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + def test_encode_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" From 0d7668026eaf09236c4f5bb1a69ca400b57c8d1d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 17:25:04 +0000 Subject: [PATCH 06/15] Add annotations and remove __constants__ Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/layers/vector_quantizer.py | 10 ++++++---- generative/networks/nets/vqvae.py | 3 --- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py index 661f2129..b06ffb3e 100644 --- a/generative/networks/layers/vector_quantizer.py +++ b/generative/networks/layers/vector_quantizer.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from __future__ import annotations + +from collections.abc import Sequence import torch from torch import nn @@ -84,7 +86,7 @@ def __init__( ) @torch.cuda.amp.autocast(enabled=False) - def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def quantize(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor, torch.Tensor]: """ Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. 
@@ -158,7 +160,7 @@ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Ten else: pass - def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor, torch.Tensor]: flat_input, encodings, encoding_indices = self.quantize(inputs) quantized = self.embed(encoding_indices) @@ -205,7 +207,7 @@ def __init__(self, quantizer: torch.nn.Module = None): self.perplexity: torch.Tensor = torch.rand(1) - def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor]: quantized, loss, encoding_indices = self.quantizer(inputs) # Perplexity calculations diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 34fba80f..29eeec2a 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -279,9 +279,6 @@ class VQVAE(nn.Module): ddp_sync: whether to synchronize the codebook across processes. """ - # < Python 3.9 TorchScript requirement for ModuleList - __constants__ = ["encoder", "quantizer", "decoder"] - def __init__( self, spatial_dims: int, From acc513758b2b14c72be2f52a44d4389af1ba007c Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 17:37:55 +0000 Subject: [PATCH 07/15] Add docstring for Encoder and Decoder Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 29eeec2a..eea035b7 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -85,6 +85,24 @@ def forward(self, x): class Encoder(nn.Module): + """ + Encoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: VectorQuantization number of channels of the input and atomic elements. + num_channels: number of channels at each level. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. + adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". + dropout: dropout ratio. + act: activation type and arguments. + + """ def __init__( self, spatial_dims: int, @@ -163,6 +181,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: VectorQuantization number of channels of the input and atomic elements. + out_channels: number of output channels. + num_channels: number of channels at each level. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). 
+ adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ def __init__( self, spatial_dims: int, @@ -276,6 +311,7 @@ class VQVAE(nn.Module): adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". act: activation type and arguments. dropout: dropout ratio. + output_act: activation type and arguments for the output. ddp_sync: whether to synchronize the codebook across processes. """ From 5e679c8f65370e4823d8b98a13cc56c2dd22333c Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 17:38:49 +0000 Subject: [PATCH 08/15] Add docstring for Encoder and Decoder Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index eea035b7..51c37817 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -95,13 +95,11 @@ class Encoder(nn.Module): num_channels: number of channels at each level. num_res_layers: number of sequential residual layers at each level. num_res_channels: number of channels in the residual layers at each level. - upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". dropout: dropout ratio. act: activation type and arguments. - """ def __init__( self, @@ -191,8 +189,9 @@ class Decoder(nn.Module): num_channels: number of channels at each level. num_res_layers: number of sequential residual layers at each level. num_res_channels: number of channels in the residual layers at each level. - downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". dropout: dropout ratio. act: activation type and arguments. 
From 84a2fd4db5e718acff9eb2fa95639b84e320d2a3 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 17:44:00 +0000 Subject: [PATCH 09/15] Set dropout as float value and update docstring Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 51c37817..e86de8d0 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -48,7 +48,7 @@ def __init__( num_res_channels: int, adn_ordering: str = "NDA", act: tuple | str | None = "RELU", - dropout: tuple | str | float | None = None, + dropout: float = None, bias: bool = True, ) -> None: super().__init__() @@ -111,7 +111,7 @@ def __init__( num_res_channels: Sequence[int], downsample_parameters: Sequence[Sequence[int, int, int, int], ...], adn_ordering: str, - dropout: tuple | str | float | None, + dropout: float, act: tuple | str | None, ) -> None: super().__init__() @@ -207,7 +207,7 @@ def __init__( num_res_channels: Sequence[int], upsample_parameters: Sequence[Sequence[int, int, int, int], ...], adn_ordering: str, - dropout: tuple | str | float | None, + dropout: float, act: tuple | str | None, output_act: tuple | str | None, ) -> None: @@ -285,10 +285,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VQVAE(nn.Module): """ - Single bottleneck implementation of Vector-Quantised Variational Autoencoder (VQ-VAE) as originally used in - Morphology-preserving Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. - (https://arxiv.org/pdf/2209.03177.pdf) and the original implementation that can be found at - https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) and the original implementation + that can be found at https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ Args: spatial_dims: number of spatial spatial_dims. @@ -333,7 +332,7 @@ def __init__( decay: float = 0.5, epsilon: float = 1e-5, adn_ordering: str = "NDA", - dropout: tuple | str | float | None = 0.1, + dropout: float = 0.0, act: tuple | str | None = "RELU", output_act: tuple | str | None = None, ddp_sync: bool = True, From cbab9745413222e5d77bf3fa7c77061541195e53 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 18:07:16 +0000 Subject: [PATCH 10/15] Fix torchscript error Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/layers/vector_quantizer.py | 10 ++++------ generative/networks/nets/vqvae.py | 6 ++++-- tests/test_vqvae.py | 12 ++++++------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py index b06ffb3e..661f2129 100644 --- a/generative/networks/layers/vector_quantizer.py +++ b/generative/networks/layers/vector_quantizer.py @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - -from collections.abc import Sequence +from typing import Sequence import torch from torch import nn @@ -86,7 +84,7 @@ def __init__( ) @torch.cuda.amp.autocast(enabled=False) - def quantize(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor, torch.Tensor]: + def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. @@ -160,7 +158,7 @@ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Ten else: pass - def forward(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: flat_input, encodings, encoding_indices = self.quantize(inputs) quantized = self.embed(encoding_indices) @@ -207,7 +205,7 @@ def __init__(self, quantizer: torch.nn.Module = None): self.perplexity: torch.Tensor = torch.rand(1) - def forward(self, inputs: torch.Tensor) -> Sequence[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: quantized, loss, encoding_indices = self.quantizer(inputs) # Perplexity calculations diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index e86de8d0..53661df0 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -91,7 +91,7 @@ class Encoder(nn.Module): Args: spatial_dims: number of spatial spatial_dims. in_channels: number of input channels. - out_channels: VectorQuantization number of channels of the input and atomic elements. + out_channels: number of channels in the latent space (embedding_dim). num_channels: number of channels at each level. num_res_layers: number of sequential residual layers at each level. num_res_channels: number of channels in the residual layers at each level. @@ -101,6 +101,7 @@ class Encoder(nn.Module): dropout: dropout ratio. act: activation type and arguments. """ + def __init__( self, spatial_dims: int, @@ -184,7 +185,7 @@ class Decoder(nn.Module): Args: spatial_dims: number of spatial spatial_dims. - in_channels: VectorQuantization number of channels of the input and atomic elements. + in_channels: number of channels in the latent space (embedding_dim). out_channels: number of output channels. num_channels: number of channels at each level. num_res_layers: number of sequential residual layers at each level. @@ -197,6 +198,7 @@ class Decoder(nn.Module): act: activation type and arguments. output_act: activation type and arguments for the output. 
""" + def __init__( self, spatial_dims: int, diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index bb5486ee..9d9ccade 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -94,8 +94,8 @@ "downsample_parameters": ((2, 4, 1, 1),) * 2, "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], + "num_channels": (8, 8), + "num_res_channels": (8, 8), "num_embeddings": 16, "embedding_dim": 8, } @@ -118,11 +118,11 @@ def test_script(self): spatial_dims=2, in_channels=1, out_channels=1, - downsample_parameters=tuple([(2, 4, 1, 1)] * 2), - upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 2), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, num_res_layers=1, - num_channels=[8, 8], - num_res_channels=[8, 8], + num_channels=(8, 8), + num_res_channels=(8, 8), num_embeddings=16, embedding_dim=8, ddp_sync=False, From 8b30f18e5d5168ada6fd667acbee86ae68fa9371 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 18:14:06 +0000 Subject: [PATCH 11/15] Update tutorials Signed-off-by: Walter Hugo Lopez Pinaya --- tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py | 5 ++--- tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py | 7 +++---- tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py | 5 ++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py index a2ebdbf7..89dc0725 100644 --- a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py +++ b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py @@ -165,12 +165,11 @@ spatial_dims=2, in_channels=1, out_channels=1, + num_channels=(256, 512), + num_res_channels=512, num_res_layers=2, - num_levels=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_channels=[256, 512], - num_res_channels=[256, 512], num_embeddings=256, embedding_dim=32, ) diff --git a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py index 79b3870b..2aae29fb 100644 --- a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py +++ b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py @@ -118,12 +118,11 @@ spatial_dims=2, in_channels=1, out_channels=1, + num_channels=(256, 256), + num_res_channels=256, + num_res_layers=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_res_layers=2, - num_levels=2, - num_channels=256, - num_res_channels=256, num_embeddings=256, embedding_dim=32, ) diff --git a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py index 1c057a3b..95786e15 100644 --- a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py +++ b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py @@ -123,12 +123,11 @@ spatial_dims=3, in_channels=1, out_channels=1, + num_channels=(256, 256), + num_res_channels=256, num_res_layers=2, - num_levels=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_channels=[256, 256], - num_res_channels=[256, 256], num_embeddings=256, embedding_dim=32, ) From cd197cd375e4f2964798b16cb013b3b66f66a53c Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 11 Feb 2023 18:15:48 +0000 Subject: [PATCH 12/15] Update tutorials Signed-off-by: Walter Hugo Lopez Pinaya --- .../generative/2d_vqgan/2d_vqgan_tutorial.ipynb | 5 ++--- 
.../generative/2d_vqvae/2d_vqvae_tutorial.ipynb | 17 +++++------------ .../generative/3d_vqvae/3d_vqvae_tutorial.ipynb | 5 ++--- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb index c2e54256..60063621 100644 --- a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb +++ b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb @@ -349,12 +349,11 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 512),\n", + " num_res_channels=512,\n", " num_res_layers=2,\n", - " num_levels=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_channels=[256, 512],\n", - " num_res_channels=[256, 512],\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", diff --git a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb index cd7feeae..714ba358 100644 --- a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb +++ b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb @@ -501,12 +501,11 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 256),\n", + " num_res_channels=256,\n", + " num_res_layers=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_res_layers=2,\n", - " num_levels=2,\n", - " num_channels=256,\n", - " num_res_channels=256,\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", @@ -615,13 +614,7 @@ "Epoch 69: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0162, quantization_loss=2.33e-5]\n", "Epoch 70: 100%|███████████████| 125/125 [00:32<00:00, 3.88it/s, recons_loss=0.0162, quantization_loss=2.5e-5]\n", "Epoch 71: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0168, quantization_loss=2.34e-5]\n", - "Epoch 72: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "Epoch 72: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n", "Epoch 73: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0166, quantization_loss=2.05e-5]\n", "Epoch 74: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0165, quantization_loss=2.36e-5]\n", "Epoch 75: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0161, quantization_loss=1.96e-5]\n", @@ -890,4 +883,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb index a898f0cc..a20a5393 100644 --- a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb +++ b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb @@ -471,12 +471,11 @@ " spatial_dims=3,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 256),\n", + " num_res_channels=256,\n", " num_res_layers=2,\n", - " num_levels=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_channels=[256, 256],\n", - " num_res_channels=[256, 256],\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", From 22524c3d35295974ff7e67e9aab514803676961a Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez 
Pinaya Date: Mon, 13 Feb 2023 22:14:57 +0000 Subject: [PATCH 13/15] remove adn_ordering Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 23 +++-------------------- tests/test_vqvaetransformer_inferer.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 53661df0..6dc73538 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -35,7 +35,6 @@ class VQVAEResidualUnit(nn.Module): spatial_dims: number of spatial spatial_dims of the input data. num_channels: number of input channels. num_res_channels: number of channels in the residual layers. - adn_ordering : a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA". act: activation type and arguments. Defaults to RELU. dropout: dropout ratio. Defaults to no dropout. bias: whether to have a bias term. Defaults to True. @@ -46,7 +45,6 @@ def __init__( spatial_dims: int, num_channels: int, num_res_channels: int, - adn_ordering: str = "NDA", act: tuple | str | None = "RELU", dropout: float = None, bias: bool = True, @@ -56,7 +54,6 @@ def __init__( self.spatial_dims = spatial_dims self.num_channels = num_channels self.num_res_channels = num_res_channels - self.adn_ordering = adn_ordering self.act = act self.dropout = dropout self.bias = bias @@ -65,9 +62,8 @@ def __init__( spatial_dims=self.spatial_dims, in_channels=self.num_channels, out_channels=self.num_res_channels, - adn_ordering=self.adn_ordering, + adn_ordering="DA", act=self.act, - norm=None, dropout=self.dropout, bias=self.bias, ) @@ -97,7 +93,6 @@ class Encoder(nn.Module): num_res_channels: number of channels in the residual layers at each level. downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the following information stride (int), kernel_size (int), dilation (int) and padding (int). - adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". dropout: dropout ratio. act: activation type and arguments. """ @@ -111,7 +106,6 @@ def __init__( num_res_layers: int, num_res_channels: Sequence[int], downsample_parameters: Sequence[Sequence[int, int, int, int], ...], - adn_ordering: str, dropout: float, act: tuple | str | None, ) -> None: @@ -123,7 +117,6 @@ def __init__( self.num_res_layers = num_res_layers self.num_res_channels = num_res_channels self.downsample_parameters = downsample_parameters - self.adn_ordering = adn_ordering self.dropout = dropout self.act = act @@ -137,9 +130,8 @@ def __init__( out_channels=self.num_channels[i], strides=self.downsample_parameters[i][0], kernel_size=self.downsample_parameters[i][1], - adn_ordering=self.adn_ordering, + adn_ordering="DA", act=self.act, - norm=None, dropout=None if i == 0 else self.dropout, dropout_dim=1, dilation=self.downsample_parameters[i][2], @@ -153,7 +145,6 @@ def __init__( spatial_dims=self.spatial_dims, num_channels=self.num_channels[i], num_res_channels=self.num_res_channels[i], - adn_ordering=self.adn_ordering, act=self.act, dropout=self.dropout, ) @@ -193,7 +184,6 @@ class Decoder(nn.Module): upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. 
- adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". dropout: dropout ratio. act: activation type and arguments. output_act: activation type and arguments for the output. @@ -208,7 +198,6 @@ def __init__( num_res_layers: int, num_res_channels: Sequence[int], upsample_parameters: Sequence[Sequence[int, int, int, int], ...], - adn_ordering: str, dropout: float, act: tuple | str | None, output_act: tuple | str | None, @@ -221,7 +210,6 @@ def __init__( self.num_res_layers = num_res_layers self.num_res_channels = num_res_channels self.upsample_parameters = upsample_parameters - self.adn_ordering = adn_ordering self.dropout = dropout self.act = act self.output_act = output_act @@ -249,7 +237,6 @@ def __init__( spatial_dims=self.spatial_dims, num_channels=reversed_num_channels[i], num_res_channels=reversed_num_res_channels[i], - adn_ordering=self.adn_ordering, act=self.act, dropout=self.dropout, ) @@ -262,7 +249,7 @@ def __init__( out_channels=self.out_channels if i == len(self.num_channels) - 1 else reversed_num_channels[i + 1], strides=self.upsample_parameters[i][0], kernel_size=self.upsample_parameters[i][1], - adn_ordering=self.adn_ordering, + adn_ordering="DA", act=self.act, dropout=self.dropout if i != len(self.num_channels) - 1 else None, norm=None, @@ -308,7 +295,6 @@ class VQVAE(nn.Module): commitment_cost: VectorQuantization commitment_cost. decay: VectorQuantization decay. epsilon: VectorQuantization epsilon. - adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". act: activation type and arguments. dropout: dropout ratio. output_act: activation type and arguments for the output. @@ -333,7 +319,6 @@ def __init__( commitment_cost: float = 0.25, decay: float = 0.5, epsilon: float = 1e-5, - adn_ordering: str = "NDA", dropout: float = 0.0, act: tuple | str | None = "RELU", output_act: tuple | str | None = None, @@ -398,7 +383,6 @@ def __init__( num_res_layers=num_res_layers, num_res_channels=num_res_channels, downsample_parameters=downsample_parameters, - adn_ordering=adn_ordering, dropout=dropout, act=act, ) @@ -411,7 +395,6 @@ def __init__( num_res_layers=num_res_layers, num_res_channels=num_res_channels, upsample_parameters=upsample_parameters, - adn_ordering=adn_ordering, dropout=dropout, act=act, output_act=output_act, diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py index edb152a3..0c10a31a 100644 --- a/tests/test_vqvaetransformer_inferer.py +++ b/tests/test_vqvaetransformer_inferer.py @@ -26,12 +26,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, + "num_channels": (8, 8), + "num_res_channels": (8, 8), "downsample_parameters": ((2, 4, 1, 1),) * 2, "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": [8, 8], "num_embeddings": 16, "embedding_dim": 8, }, @@ -52,12 +51,11 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_levels": 2, + "num_channels": (8, 8), + "num_res_channels": (8, 8), "downsample_parameters": ((2, 4, 1, 1),) * 2, "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": [8, 8], "num_embeddings": 16, "embedding_dim": 8, }, @@ -100,12 +98,11 @@ def test_sample(self): spatial_dims=2, in_channels=1, out_channels=1, - num_levels=2, + num_channels=(8, 8), + num_res_channels=(8, 8), downsample_parameters=((2, 4, 1, 1),) * 2, upsample_parameters=((2, 4, 1, 1, 0),) * 2, 
num_res_layers=1, - num_channels=8, - num_res_channels=(8, 8), num_embeddings=16, embedding_dim=8, ) From 6caa9389e9f3fcebc12ed95f3cf88c7def21a5e5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Tue, 14 Feb 2023 11:14:08 +0000 Subject: [PATCH 14/15] Add ActType Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 6dc73538..a12d2d98 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -23,6 +23,8 @@ __all__ = ["VQVAE"] +ActType = tuple | str | None + class VQVAEResidualUnit(nn.Module): """ @@ -45,8 +47,8 @@ def __init__( spatial_dims: int, num_channels: int, num_res_channels: int, - act: tuple | str | None = "RELU", - dropout: float = None, + act: ActType = Act.RELU, + dropout: float = 0.0, bias: bool = True, ) -> None: super().__init__() @@ -107,7 +109,7 @@ def __init__( num_res_channels: Sequence[int], downsample_parameters: Sequence[Sequence[int, int, int, int], ...], dropout: float, - act: tuple | str | None, + act: ActType, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -199,8 +201,8 @@ def __init__( num_res_channels: Sequence[int], upsample_parameters: Sequence[Sequence[int, int, int, int], ...], dropout: float, - act: tuple | str | None, - output_act: tuple | str | None, + act: ActType, + output_act: ActType, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -320,8 +322,8 @@ def __init__( decay: float = 0.5, epsilon: float = 1e-5, dropout: float = 0.0, - act: tuple | str | None = "RELU", - output_act: tuple | str | None = None, + act: ActType = Act.RELU, + output_act: ActType = None, ddp_sync: bool = True, ): super().__init__() From b849095b17ee8d323f79d1b17249e350f568c3ca Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 15 Feb 2023 17:54:58 +0000 Subject: [PATCH 15/15] Remove text about subpixel layers and ActType. Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/vqvae.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index a12d2d98..39335587 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -23,8 +23,6 @@ __all__ = ["VQVAE"] -ActType = tuple | str | None - class VQVAEResidualUnit(nn.Module): """ @@ -47,7 +45,7 @@ def __init__( spatial_dims: int, num_channels: int, num_res_channels: int, - act: ActType = Act.RELU, + act: tuple | str | None = Act.RELU, dropout: float = 0.0, bias: bool = True, ) -> None: @@ -109,7 +107,7 @@ def __init__( num_res_channels: Sequence[int], downsample_parameters: Sequence[Sequence[int, int, int, int], ...], dropout: float, - act: ActType, + act: tuple | str | None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -185,7 +183,6 @@ class Decoder(nn.Module): num_res_channels: number of channels in the residual layers at each level. upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. dropout: dropout ratio. act: activation type and arguments. output_act: activation type and arguments for the output. 
@@ -201,8 +198,8 @@ def __init__( num_res_channels: Sequence[int], upsample_parameters: Sequence[Sequence[int, int, int, int], ...], dropout: float, - act: ActType, - output_act: ActType, + act: tuple | str | None, + output_act: tuple | str | None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -288,7 +285,6 @@ class VQVAE(nn.Module): following information stride (int), kernel_size (int), dilation (int) and padding (int). upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. num_res_layers: number of sequential residual layers at each level. num_channels: number of channels at each level. num_res_channels: number of channels in the residual layers at each level. @@ -322,8 +318,8 @@ def __init__( decay: float = 0.5, epsilon: float = 1e-5, dropout: float = 0.0, - act: ActType = Act.RELU, - output_act: ActType = None, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, ddp_sync: bool = True, ): super().__init__()
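---

Editor's note: after the closing patches in this series, the `VQVAE` constructor no longer accepts `num_levels` or `adn_ordering`, and the short-lived `ActType` alias is gone again. The number of levels is implied by the length of `num_channels`, and the internal convolutions are hard-wired to a dropout-then-activation ("DA") ordering with no normalization layer. Below is a minimal sketch of the resulting API, not part of the patch series itself; it assumes the import path follows the file layout in the diffs (generative/networks/nets/vqvae.py) and that the forward pass returns a (reconstruction, quantization_loss) pair, as the recons_loss/quantization_loss entries in the tutorial progress bars suggest.

    import torch

    from generative.networks.nets.vqvae import VQVAE

    # Two entries in num_channels imply two downsampling/upsampling levels;
    # a separate num_levels argument is neither needed nor accepted any more.
    net = VQVAE(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(256, 256),
        num_res_channels=256,  # a plain int is accepted, as in the updated tutorials
        num_res_layers=2,
        downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),  # stride, kernel_size, dilation, padding
        upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),  # ... plus output_padding
        num_embeddings=256,
        embedding_dim=32,
    )

    x = torch.randn(1, 1, 64, 64)  # batch, channel, H, W; two stride-2 levels give a 16x16 latent grid
    reconstruction, quantization_loss = net(x)  # assumed return signature, hinted at by the tutorial logs

Because patch 13 removes the `adn_ordering` argument rather than merely changing its default, configurations that previously passed a custom ordering (such as the old "NDA" default) have no replacement knob: the residual units and the sampling convolutions now always apply dropout before the activation.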