diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index 62eab739..f51dc581 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -74,7 +74,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.remap_output: x = self.channel_mapper(x) - for stage in range(self.n_stages): + for _ in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) return x diff --git a/generative/networks/blocks/spade_norm.py b/generative/networks/blocks/spade_norm.py index 0fe735e8..2e7bbcef 100644 --- a/generative/networks/blocks/spade_norm.py +++ b/generative/networks/blocks/spade_norm.py @@ -39,11 +39,12 @@ def __init__( spatial_dims: int = 2, hidden_channels: int = 64, norm: str | tuple = "INSTANCE", - norm_params: dict = {}, + norm_params: dict | None = None, ) -> None: - super().__init__() + if norm_params is None: + norm_params = {} if len(norm_params) != 0: norm = (norm, norm_params) self.param_free_norm = ADN( diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py index 8d4808ab..0d5d18bb 100644 --- a/generative/networks/nets/spade_network.py +++ b/generative/networks/nets/spade_network.py @@ -23,21 +23,24 @@ from generative.networks.blocks.spade_norm import SPADE + class KLDLoss(nn.Module): """ Computes the Kullback-Leibler divergence between a normal distribution with mean mu and variance logvar and one with mean 0 and variance 1. """ + def forward(self, mu, logvar): return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + class UpsamplingModes(StrEnum): bicubic = "bicubic" nearest = "nearest" bilinear = "bilinear" -class SPADE_ResNetBlock(nn.Module): +class SPADEResNetBlock(nn.Module): """ Creates a Residual Block with SPADE normalisation. 
@@ -61,7 +64,6 @@ def __init__( norm: str | tuple = "INSTANCE", kernel_size: int = 3, ): - super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -128,7 +130,7 @@ def shortcut(self, x, seg): return x_s -class SPADE_Encoder(nn.Module): +class SPADEEncoder(nn.Module): """ Encoding branch of a VAE compatible with a SPADE-like generator @@ -155,7 +157,6 @@ def __init__( norm: str | tuple = "INSTANCE", act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), ): - super().__init__() self.in_channels = in_channels self.z_dim = z_dim @@ -172,7 +173,7 @@ def __init__( self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in self.input_shape] blocks = [] ch_init = self.in_channels - for ch_ind, ch_value in enumerate(num_channels): + for _, ch_value in enumerate(num_channels): blocks.append( Convolution( spatial_dims=spatial_dims, @@ -211,13 +212,12 @@ def encode(self, x): return self.reparameterize(mu, logvar) def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return eps.mul(std) + mu -class SPADE_Decoder(nn.Module): +class SPADEDecoder(nn.Module): """ Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, behaving like a GAN, or coupled to a SPADE encoder. 
@@ -255,7 +255,6 @@ def __init__( kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, ): - super().__init__() self.is_gan = is_gan self.out_channels = out_channels @@ -281,7 +280,7 @@ def __init__( self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) for ch_ind, ch_value in enumerate(num_channels[:-1]): blocks.append( - SPADE_ResNetBlock( + SPADEResNetBlock( spatial_dims=spatial_dims, in_channels=ch_value, out_channels=num_channels[ch_ind + 1], @@ -321,10 +320,11 @@ def forward(self, seg, z: torch.Tensor = None): return x -class SPADE_Net(nn.Module): +class SPADENet(nn.Module): """ - SPADE Network, implemented based on the code by Park, T et al. in "Semantic Image Synthesis with Spatially-Adaptive Normalization" + SPADE Network, implemented based on the code by Park, T et al. in + "Semantic Image Synthesis with Spatially-Adaptive Normalization" (https://github.com/NVlabs/SPADE) Args: @@ -361,7 +361,6 @@ def __init__( kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, ): - super().__init__() self.is_vae = is_vae if self.is_vae and z_dim is None: @@ -375,7 +374,7 @@ def __init__( self.kld_loss = KLDLoss() if self.is_vae: - self.encoder = SPADE_Encoder( + self.encoder = SPADEEncoder( spatial_dims=spatial_dims, in_channels=in_channels, z_dim=z_dim, @@ -389,7 +388,7 @@ def __init__( decoder_channels = num_channels decoder_channels.reverse() - self.decoder = SPADE_Decoder( + self.decoder = SPADEDecoder( spatial_dims=spatial_dims, out_channels=out_channels, label_nc=label_nc, @@ -416,9 +415,7 @@ def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): return (self.decoder(seg, z),) def encode(self, x: torch.Tensor): - return self.encoder.encode(x) def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): - return self.decoder(seg, z) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index e030b81e..048549dd 100644 --- a/tests/test_spade_vaegan.py +++ 
b/tests/test_spade_vaegan.py @@ -124,7 +124,7 @@ def test_shape_wrong(self): We input an input shape that isn't divisible by 2**(n downstream steps) """ with self.assertRaises(ValueError): - net = SPADE_Net(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) if __name__ == "__main__":