Merged

Fixes #211

34 changes: 16 additions & 18 deletions generative/networks/nets/vqvae.py
@@ -100,26 +100,24 @@ class VQVAE(nn.Module):
         spatial_dims: number of spatial dimensions.
         in_channels: number of input channels.
         out_channels: number of output channels.
-        num_levels: number of levels that the network has. Defaults to 3.
+        num_levels: number of levels that the network has.
         downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
-            following information stride (int), kernel_size (int), dilation(int) and padding (int).
-            Defaults to ((2,4,1,1),(2,4,1,1),(2,4,1,1)).
+            following information stride (int), kernel_size (int), dilation (int) and padding (int).
         upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
-            following information stride (int), kernel_size (int), dilation(int), padding (int), output_padding (int).
+            following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
             If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor.
-            Defaults to ((2,4,1,1,0),(2,4,1,1,0),(2,4,1,1,0)).
-        num_res_layers: number of sequential residual layers at each level. Defaults to 3.
-        num_channels: number of channels at the deepest level, besides that is num_channels//2. Defaults to 192.
-        num_res_channels: number of channels in the residual layers. Defaults to 64.
-        num_embeddings: VectorQuantization number of atomic elements in the codebook. Defaults to 32.
-        embedding_dim: VectorQuantization number of channels of the input and atomic elements. Defaults to 64.
-        commitment_cost: VectorQuantization commitment_cost. Defaults to 0.25.
-        decay: VectorQuantization decay. Defaults to 0.5.
-        epsilon: VectorQuantization epsilon. Defaults to 1e-5.
-        adn_ordering: a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA".
-        act: activation type and arguments. Defaults to Relu.
-        dropout: dropout ratio. Defaults to 0.1.
-        ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
+        num_res_layers: number of sequential residual layers at each level.
+        num_channels: number of channels at each level.
+        num_res_channels: number of channels in the residual layers at each level.
+        num_embeddings: VectorQuantization number of atomic elements in the codebook.
+        embedding_dim: VectorQuantization number of channels of the input and atomic elements.
+        commitment_cost: VectorQuantization commitment_cost.
+        decay: VectorQuantization decay.
+        epsilon: VectorQuantization epsilon.
+        adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA".
+        act: activation type and arguments.
+        dropout: dropout ratio.
+        ddp_sync: whether to synchronize the codebook across processes.
         """

# < Python 3.9 TorchScript requirement for ModuleList
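For context, the sketch below shows how the per-level arguments line up after this change: every sequence-valued argument must supply one entry per level. The import path follows this repository's file layout, but the constructor call and the example values are illustrative assumptions, not part of this PR.

# Minimal usage sketch (not from the PR); example values, not recommended settings.
from generative.networks.nets.vqvae import VQVAE

net = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_levels=3,
    # One (stride, kernel_size, dilation, padding) tuple per level.
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)),
    # One (stride, kernel_size, dilation, padding, output_padding) tuple per level.
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    num_res_layers=3,
    # After this change, num_channels and num_res_channels are given per level,
    # so each must have exactly num_levels entries.
    num_channels=(96, 96, 192),
    num_res_channels=(32, 64, 64),
    num_embeddings=32,
    embedding_dim=64,
)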
@@ -166,7 +164,7 @@ def __init__(
        ), (
            f"downsample_parameters, upsample_parameters, num_channels and num_res_channels must have the same number of"
            f" elements as num_levels. But got {len(downsample_parameters)}, {len(upsample_parameters)}, "
-            f"{len(num_res_channels)} and {len(num_res_channels)} instead of {num_levels}."
+            f"{len(num_channels)} and {len(num_res_channels)} instead of {num_levels}."
        )

self.num_levels = num_levels
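A self-contained sketch of the corrected check follows. The condition is paraphrased from the constructor (only the error message is visible in this hunk) and the argument values are illustrative; it shows that the third interpolated length now comes from num_channels, so a mismatch there is reported correctly.

# Sketch of the corrected validation; running it raises AssertionError on purpose.
num_levels = 3
downsample_parameters = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1))
upsample_parameters = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0))
num_channels = (96, 96)  # deliberately one entry short to trigger the assert
num_res_channels = (32, 64, 64)

assert all(
    len(seq) == num_levels
    for seq in (downsample_parameters, upsample_parameters, num_channels, num_res_channels)
), (
    f"downsample_parameters, upsample_parameters, num_channels and num_res_channels must have the same number of"
    f" elements as num_levels. But got {len(downsample_parameters)}, {len(upsample_parameters)}, "
    f"{len(num_channels)} and {len(num_res_channels)} instead of {num_levels}."
    # Before this PR the message interpolated len(num_res_channels) twice,
    # so a mismatch in num_channels was reported with the wrong length.
)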