From ebdc1cbc8be400dc74d6dbdbf3e10b2087c03d1b Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 8 Jan 2023 18:10:51 +0000 Subject: [PATCH] Fix kernel_size in quant_conv and post_quant_conv layers Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 30 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index f1a75636..d1eac073 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -543,9 +543,33 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, ) - self.quant_conv_mu = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) - self.quant_conv_log_sigma = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) - self.post_quant_conv = Convolution(spatial_dims, latent_channels, latent_channels, 1, conv_only=True) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) self.latent_channels = latent_channels def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: