diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index f1a75636..d1eac073 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -543,9 +543,33 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, ) - self.quant_conv_mu = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) - self.quant_conv_log_sigma = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) - self.post_quant_conv = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) self.latent_channels = latent_channels def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: