diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index feceed13..39335587 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -35,10 +35,8 @@ class VQVAEResidualUnit(nn.Module): spatial_dims: number of spatial spatial_dims of the input data. num_channels: number of input channels. num_res_channels: number of channels in the residual layers. - adn_ordering : a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA". act: activation type and arguments. Defaults to RELU. dropout: dropout ratio. Defaults to no dropout. - dropout_dim: dimension along which to apply dropout. Defaults to 1. bias: whether to have a bias term. Defaults to True. """ @@ -47,10 +45,8 @@ def __init__( spatial_dims: int, num_channels: int, num_res_channels: int, - adn_ordering: str = "NDA", - act: tuple | str | None = "RELU", - dropout: tuple | str | float | None = None, - dropout_dim: int | None = 1, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, bias: bool = True, ) -> None: super().__init__() @@ -58,21 +54,17 @@ def __init__( self.spatial_dims = spatial_dims self.num_channels = num_channels self.num_res_channels = num_res_channels - self.adn_ordering = adn_ordering self.act = act self.dropout = dropout - self.dropout_dim = dropout_dim self.bias = bias self.conv1 = Convolution( spatial_dims=self.spatial_dims, in_channels=self.num_channels, out_channels=self.num_res_channels, - adn_ordering=self.adn_ordering, + adn_ordering="DA", act=self.act, - norm=None, dropout=self.dropout, - dropout_dim=self.dropout_dim, bias=self.bias, ) @@ -80,270 +72,344 @@ def __init__( spatial_dims=self.spatial_dims, in_channels=self.num_res_channels, out_channels=self.num_channels, - adn_ordering=self.adn_ordering, - act=None, - norm=None, - dropout=None, - dropout_dim=self.dropout_dim, bias=self.bias, + conv_only=True, ) def forward(self, x): return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) -class VQVAE(nn.Module): +class Encoder(nn.Module): """ - Single bottleneck implementation of Vector-Quantised Variational Autoencoder (VQ-VAE) as originally used in - Morphology-preserving Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. - (https://arxiv.org/pdf/2209.03177.pdf) and the original implementation that can be found at - https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + Encoder module for VQ-VAE. Args: spatial_dims: number of spatial spatial_dims. in_channels: number of input channels. - out_channels: number of output channels. - num_levels: number of levels that the network has. - downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int) and padding (int). - upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor. - num_res_layers: number of sequential residual layers at each level. + out_channels: number of channels in the latent space (embedding_dim). num_channels: number of channels at each level. + num_res_layers: number of sequential residual layers at each level. num_res_channels: number of channels in the residual layers at each level. 
- num_embeddings: VectorQuantization number of atomic elements in the codebook. - embedding_dim: VectorQuantization number of channels of the input and atomic elements. - commitment_cost: VectorQuantization commitment_cost. - decay: VectorQuantization decay. - epsilon: VectorQuantization epsilon. - adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA". - act: activation type and arguments. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). dropout: dropout ratio. - ddp_sync: whether to synchronize the codebook across processes. + act: activation type and arguments. """ - # < Python 3.9 TorchScript requirement for ModuleList - __constants__ = ["encoder", "quantizer", "decoder"] - def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, - num_levels: int = 3, - downsample_parameters: tuple[tuple[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: tuple[tuple[int, int, int, int, int], ...] = ( - (2, 4, 1, 1, 0), - (2, 4, 1, 1, 0), - (2, 4, 1, 1, 0), - ), - num_res_layers: int = 3, - num_channels: Sequence[int] | int = (96, 96, 192), - num_res_channels: Sequence[int] | int = (96, 96, 192), - num_embeddings: int = 32, - embedding_dim: int = 64, - embedding_init: str = "normal", - commitment_cost: float = 0.25, - decay: float = 0.5, - epsilon: float = 1e-5, - adn_ordering: str = "NDA", - dropout: tuple | str | float | None = 0.1, - act: tuple | str | None = "RELU", - output_act: tuple | str | None = None, - ddp_sync: bool = True, - ): + num_channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Sequence[int, int, int, int], ...], + dropout: float, + act: tuple | str | None, + ) -> None: super().__init__() - + self.spatial_dims = spatial_dims self.in_channels = in_channels self.out_channels = out_channels - self.spatial_dims = spatial_dims - - if isinstance(num_channels, int): - num_channels = ensure_tuple_rep(num_channels, num_levels) - if isinstance(num_res_channels, int): - num_res_channels = ensure_tuple_rep(num_res_channels, num_levels) - - assert ( - num_levels == len(downsample_parameters) - and num_levels == len(upsample_parameters) - and num_levels == len(num_channels) - and num_levels == len(num_res_channels) - ), ( - f"downsample_parameters, upsample_parameters, num_channels and num_res_channels must have the same number of" - f" elements as num_levels. But got {len(downsample_parameters)}, {len(upsample_parameters)}, " - f"{len(num_channels)} and {len(num_res_channels)} instead of {num_levels}." 
- ) - - self.num_levels = num_levels - self.downsample_parameters = downsample_parameters - self.upsample_parameters = upsample_parameters - self.num_res_layers = num_res_layers self.num_channels = num_channels + self.num_res_layers = num_res_layers self.num_res_channels = num_res_channels - + self.downsample_parameters = downsample_parameters self.dropout = dropout self.act = act - self.adn_ordering = adn_ordering - - self.output_act = output_act - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.embedding_init = embedding_init - self.commitment_cost = commitment_cost - self.decay = decay - self.epsilon = epsilon + blocks = [] - self.ddp_sync = ddp_sync - - self.encoder = self.construct_encoder() - self.quantizer = self.construct_quantizer() - self.decoder = self.construct_decoder() - - def construct_encoder(self) -> torch.nn.Sequential: - encoder = [] - - for idx in range(self.num_levels): - encoder.append( + for i in range(len(self.num_channels)): + blocks.append( Convolution( spatial_dims=self.spatial_dims, - in_channels=self.in_channels if idx == 0 else self.num_channels[idx - 1], - out_channels=self.num_channels[idx], - strides=self.downsample_parameters[idx][0], - kernel_size=self.downsample_parameters[idx][1], - adn_ordering=self.adn_ordering, + in_channels=self.in_channels if i == 0 else self.num_channels[i - 1], + out_channels=self.num_channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", act=self.act, - norm=None, - dropout=None if idx == 0 else self.dropout, + dropout=None if i == 0 else self.dropout, dropout_dim=1, - dilation=self.downsample_parameters[idx][2], - groups=1, - bias=True, - conv_only=False, - is_transposed=False, - padding=self.downsample_parameters[idx][3], - output_padding=None, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], ) ) for _ in range(self.num_res_layers): - encoder.append( + blocks.append( VQVAEResidualUnit( spatial_dims=self.spatial_dims, - num_channels=self.num_channels[idx], - num_res_channels=self.num_res_channels[idx], - adn_ordering=self.adn_ordering, + num_channels=self.num_channels[i], + num_res_channels=self.num_res_channels[i], act=self.act, dropout=self.dropout, - dropout_dim=1, - bias=True, ) ) - encoder.append( + blocks.append( Convolution( spatial_dims=self.spatial_dims, in_channels=self.num_channels[len(self.num_channels) - 1], - out_channels=self.embedding_dim, + out_channels=self.out_channels, strides=1, kernel_size=3, - adn_ordering=self.adn_ordering, - act=None, - norm=None, - dropout=None, - dropout_dim=1, - dilation=1, - bias=True, - conv_only=True, - is_transposed=False, padding=1, - output_padding=None, + conv_only=True, ) ) - return torch.nn.Sequential(*encoder) + self.blocks = nn.ModuleList(blocks) - # TODO: Include lucidrains' vector quantizer as an option - def construct_quantizer(self) -> torch.nn.Module: - return VectorQuantizer( - quantizer=EMAQuantizer( - spatial_dims=self.spatial_dims, - num_embeddings=self.num_embeddings, - embedding_dim=self.embedding_dim, - commitment_cost=self.commitment_cost, - decay=self.decay, - epsilon=self.epsilon, - embedding_init=self.embedding_init, - ddp_sync=self.ddp_sync, - ) - ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. 
+ in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + num_channels: number of channels at each level. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Sequence[int, int, int, int], ...], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act - def construct_decoder(self) -> torch.nn.Sequential: - decoder_num_channels = list(reversed(self.num_channels)) - decoder_num_res_channels = list(reversed(self.num_res_channels)) + reversed_num_channels = list(reversed(self.num_channels)) - decoder = [ + blocks = [] + blocks.append( Convolution( spatial_dims=self.spatial_dims, - in_channels=self.embedding_dim, - out_channels=decoder_num_channels[0], + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], strides=1, kernel_size=3, - adn_ordering=self.adn_ordering, - act=None, - dropout=None, - norm=None, - dropout_dim=1, - dilation=1, - bias=True, - is_transposed=False, padding=1, - output_padding=None, + conv_only=True, ) - ] + ) - for idx in range(self.num_levels): + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.num_channels)): for _ in range(self.num_res_layers): - decoder.append( + blocks.append( VQVAEResidualUnit( spatial_dims=self.spatial_dims, - num_channels=decoder_num_channels[idx], - num_res_channels=decoder_num_res_channels[idx], - adn_ordering=self.adn_ordering, + num_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], act=self.act, dropout=self.dropout, - dropout_dim=1, - bias=True, ) ) - decoder.append( + blocks.append( Convolution( spatial_dims=self.spatial_dims, - in_channels=decoder_num_channels[idx], - out_channels=self.out_channels if idx == self.num_levels - 1 else decoder_num_channels[idx + 1], - strides=self.upsample_parameters[idx][0], - kernel_size=self.upsample_parameters[idx][1], - adn_ordering=self.adn_ordering, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.num_channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", act=self.act, - dropout=self.dropout if idx != self.num_levels - 1 else None, + dropout=self.dropout if i != len(self.num_channels) - 1 else None, norm=None, - dropout_dim=1, - dilation=self.upsample_parameters[idx][2], - groups=1, - bias=True, - conv_only=idx == self.num_levels - 1, + dilation=self.upsample_parameters[i][2], + 
conv_only=i == len(self.num_channels) - 1, is_transposed=True, - padding=self.upsample_parameters[idx][3], - output_padding=self.upsample_parameters[idx][4], + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], ) ) if self.output_act: - decoder.append(Act[self.output_act]()) + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) and the original implementation + that can be found at https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + num_channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost. + decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channels: Sequence[int] | int = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Sequence[int, int, int, int], ...] + | Sequence[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Sequence[int, int, int, int, int], ...] + | Sequence[int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + ): + super().__init__() - return torch.nn.Sequential(*decoder) + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.num_channels = num_channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels)) + + if len(num_res_channels) != len(num_channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." 
+ ) + + if not all(isinstance(values, (int, Sequence)) for values in downsample_parameters): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + if not all(isinstance(values, (int, Sequence)) for values in upsample_parameters): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters = (upsample_parameters,) * len(num_channels) + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters = (downsample_parameters,) * len(num_channels) + + for parameter in downsample_parameters: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters) != len(num_channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters) != len(num_channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + num_channels=num_channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + num_channels=num_channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) def encode(self, images: torch.Tensor) -> torch.Tensor: return self.encoder(images) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 2d4dbb20..9623c76a 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -54,12 +54,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), - "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - "num_res_layers": 1, "num_channels": [4, 4], + "num_res_layers": 1, "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), "num_embeddings": 16, "embedding_dim": 3, }, diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 8c0226ea..9d9ccade 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -26,146 +26,78 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (4, 4), "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": [8, 8], - "num_embeddings": 16, + 
"num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (4, 4), "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": 8, - "num_embeddings": 16, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (4, 4), "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], - "num_embeddings": 16, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16), - (1, 1, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], [ { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "num_channels": (4, 4), "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], - "num_embeddings": 16, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, }, - (1, 1, 16, 16, 16), - (1, 1, 16, 16, 16), + (1, 1, 8, 8), + (1, 1, 8, 8), ], ] -# 1-channel 2D, should fail because of number of levels, number of downsamplings, number of upsamplings mismatch. 
-TEST_CASE_FAIL = { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_levels": 3, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 4, - "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], - "num_embeddings": 16, - "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, -} - TEST_LATENT_SHAPE = { "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, - "downsample_parameters": [(2, 4, 1, 1)] * 2, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": [8, 8], - "num_res_channels": [8, 8], + "num_channels": (8, 8), + "num_res_channels": (8, 8), "num_embeddings": 16, "embedding_dim": 8, - "embedding_init": "normal", - "commitment_cost": 0.25, - "decay": 0.5, - "epsilon": 1e-5, - "adn_ordering": "NDA", - "dropout": 0.1, - "act": "RELU", - "output_act": None, } @@ -186,30 +118,101 @@ def test_script(self): spatial_dims=2, in_channels=1, out_channels=1, - num_levels=2, - downsample_parameters=tuple([(2, 4, 1, 1)] * 2), - upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 2), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, num_res_layers=1, - num_channels=[8, 8], - num_res_channels=[8, 8], + num_channels=(8, 8), + num_res_channels=(8, 8), num_embeddings=16, embedding_dim=8, - embedding_init="normal", - commitment_cost=0.25, - decay=0.5, - epsilon=1e-5, - adn_ordering="NDA", - dropout=0.1, - act="RELU", - output_act=None, ddp_sync=False, ) test_data = torch.randn(1, 1, 16, 16) test_script_save(net, test_data) - def test_level_upsample_downsample_difference(self): - with self.assertRaises(AssertionError): - VQVAE(**TEST_CASE_FAIL) + def test_num_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) 
+ + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) def test_encode_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py index edb152a3..0c10a31a 100644 --- a/tests/test_vqvaetransformer_inferer.py +++ b/tests/test_vqvaetransformer_inferer.py @@ -26,12 +26,11 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 2, + "num_channels": (8, 8), + "num_res_channels": (8, 8), "downsample_parameters": ((2, 4, 1, 1),) * 2, "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": [8, 8], "num_embeddings": 16, "embedding_dim": 8, }, @@ -52,12 +51,11 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_levels": 2, + "num_channels": (8, 8), + "num_res_channels": (8, 8), "downsample_parameters": ((2, 4, 1, 1),) * 2, "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, "num_res_layers": 1, - "num_channels": 8, - "num_res_channels": [8, 8], "num_embeddings": 16, "embedding_dim": 8, }, @@ -100,12 +98,11 @@ def test_sample(self): spatial_dims=2, in_channels=1, out_channels=1, - num_levels=2, + num_channels=(8, 8), + num_res_channels=(8, 8), downsample_parameters=((2, 4, 1, 1),) * 2, upsample_parameters=((2, 4, 1, 1, 0),) * 2, num_res_layers=1, - num_channels=8, - num_res_channels=(8, 8), num_embeddings=16, embedding_dim=8, ) diff --git a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb index c2e54256..60063621 100644 --- a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb +++ b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb @@ -349,12 +349,11 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 512),\n", + " num_res_channels=512,\n", " num_res_layers=2,\n", - " num_levels=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_channels=[256, 512],\n", - " num_res_channels=[256, 512],\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", diff --git a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py index a2ebdbf7..89dc0725 100644 --- a/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py +++ b/tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py @@ -165,12 +165,11 @@ spatial_dims=2, in_channels=1, out_channels=1, + num_channels=(256, 512), + num_res_channels=512, num_res_layers=2, - num_levels=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_channels=[256, 512], - num_res_channels=[256, 512], num_embeddings=256, embedding_dim=32, ) diff --git a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb index cd7feeae..714ba358 100644 --- a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb +++ 
b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.ipynb @@ -501,12 +501,11 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 256),\n", + " num_res_channels=256,\n", + " num_res_layers=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_res_layers=2,\n", - " num_levels=2,\n", - " num_channels=256,\n", - " num_res_channels=256,\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", @@ -615,13 +614,7 @@ "Epoch 69: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0162, quantization_loss=2.33e-5]\n", "Epoch 70: 100%|███████████████| 125/125 [00:32<00:00, 3.88it/s, recons_loss=0.0162, quantization_loss=2.5e-5]\n", "Epoch 71: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0168, quantization_loss=2.34e-5]\n", - "Epoch 72: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "Epoch 72: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n", "Epoch 73: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0166, quantization_loss=2.05e-5]\n", "Epoch 74: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0165, quantization_loss=2.36e-5]\n", "Epoch 75: 100%|██████████████| 125/125 [00:32<00:00, 3.89it/s, recons_loss=0.0161, quantization_loss=1.96e-5]\n", @@ -890,4 +883,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py index 79b3870b..2aae29fb 100644 --- a/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py +++ b/tutorials/generative/2d_vqvae/2d_vqvae_tutorial.py @@ -118,12 +118,11 @@ spatial_dims=2, in_channels=1, out_channels=1, + num_channels=(256, 256), + num_res_channels=256, + num_res_layers=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_res_layers=2, - num_levels=2, - num_channels=256, - num_res_channels=256, num_embeddings=256, embedding_dim=32, ) diff --git a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb index a898f0cc..a20a5393 100644 --- a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb +++ b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.ipynb @@ -471,12 +471,11 @@ " spatial_dims=3,\n", " in_channels=1,\n", " out_channels=1,\n", + " num_channels=(256, 256),\n", + " num_res_channels=256,\n", " num_res_layers=2,\n", - " num_levels=2,\n", " downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n", " upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n", - " num_channels=[256, 256],\n", - " num_res_channels=[256, 256],\n", " num_embeddings=256,\n", " embedding_dim=32,\n", ")\n", diff --git a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py index 1c057a3b..95786e15 100644 --- a/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py +++ b/tutorials/generative/3d_vqvae/3d_vqvae_tutorial.py @@ -123,12 +123,11 @@ spatial_dims=3, in_channels=1, out_channels=1, + num_channels=(256, 256), + num_res_channels=256, num_res_layers=2, - num_levels=2, downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_channels=[256, 256], - num_res_channels=[256, 256], 
num_embeddings=256, embedding_dim=32, )
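
Editorial note (not part of the diff): the sketch below illustrates the refactored `VQVAE` interface introduced by this patch, using argument values taken from `TEST_LATENT_SHAPE` and the test cases above. The import path and the `forward` return signature (reconstruction, quantization loss) are assumptions based on the surrounding tests and tutorials rather than on code shown in this diff; treat it as a usage sketch, not as the library's documented API.

    import torch

    from generative.networks.nets import VQVAE  # assumed import path, as used by the tests

    # Construction mirrors TEST_LATENT_SHAPE: per-level settings are given as tuples, and a
    # scalar num_res_channels or a single flat parameter tuple is broadcast to every level
    # by the new constructor checks.
    net = VQVAE(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 8),
        num_res_layers=1,
        num_res_channels=8,                  # int -> broadcast to (8, 8)
        downsample_parameters=(2, 4, 1, 1),  # one (stride, kernel_size, dilation, padding) reused at each level
        upsample_parameters=((2, 4, 1, 1, 0),) * 2,
        num_embeddings=16,
        embedding_dim=8,
    )

    x = torch.randn(1, 1, 16, 16)
    z = net.encode(x)                           # (1, 8, 4, 4): two stride-2 levels reduce each spatial dim by 4
    reconstruction, quantization_loss = net(x)  # assumed forward signature, as in the tutorials

    # Length mismatches are now reported as ValueError instead of a bare assert:
    try:
        VQVAE(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            num_channels=(16, 16),
            num_res_channels=(16, 16, 16),
            downsample_parameters=((2, 4, 1, 1),) * 2,
            upsample_parameters=((2, 4, 1, 1, 0),) * 2,
        )
    except ValueError as err:
        print(err)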