diff --git a/generative/losses/kld_loss.py b/generative/losses/kld_loss.py deleted file mode 100644 index ebcaf52b..00000000 --- a/generative/losses/kld_loss.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import torch -import torch.nn as nn - - -class KLDLoss(nn.Module): - def forward(self, mu, logvar): - return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py index b2ff2833..8d4808ab 100644 --- a/generative/networks/nets/spade_network.py +++ b/generative/networks/nets/spade_network.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Sequence, Union +from typing import Sequence import numpy as np import torch @@ -21,9 +21,15 @@ from monai.networks.layers import Act from monai.utils.enums import StrEnum -from generative.losses.kld_loss import KLDLoss from generative.networks.blocks.spade_norm import SPADE +class KLDLoss(nn.Module): + """ + Computes the Kullback-Leibler divergence between a normal distribution with mean mu and variance logvar and + one with mean 0 and variance 1. + """ + def forward(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) class UpsamplingModes(StrEnum): bicubic = "bicubic" @@ -52,7 +58,7 @@ def __init__( out_channels: int, label_nc: int, spade_intermediate_channels: int = 128, - norm: Union[str, tuple] = "INSTANCE", + norm: str | tuple = "INSTANCE", kernel_size: int = 3, ): @@ -146,8 +152,8 @@ def __init__( num_channels: Sequence[int], input_shape: Sequence[int], kernel_size: int = 3, - norm: Union[str, tuple] = "INSTANCE", - act: Union[str, tuple] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), ): super().__init__() @@ -240,12 +246,12 @@ def __init__( label_nc: int, input_shape: Sequence[int], num_channels: Sequence[int], - z_dim: Union[int, None] = None, + z_dim: int | None = None, is_gan: bool = False, spade_intermediate_channels: int = 128, - norm: Union[str, tuple] = "INSTANCE", - act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "INSTANCE", + act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, ): @@ -346,12 +352,12 @@ def __init__( label_nc: int, input_shape: Sequence[int], num_channels: Sequence[int], - z_dim: Union[int, None] = None, + z_dim: int | None = None, is_vae: bool = True, spade_intermediate_channels: int = 128, - norm: Union[str, tuple] = "INSTANCE", - act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "INSTANCE", + act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, ): @@ -399,7 +405,7 @@ def __init__( upsampling_mode=upsampling_mode, ) - def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): + def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): z = None if self.is_vae: z_mu, z_logvar = self.encoder(x) @@ -413,6 +419,6 @@ def encode(self, x: torch.Tensor): return self.encoder.encode(x) - def decode(self, seg: torch.Tensor, z: Union[torch.Tensor, None] = None): + def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): return self.decoder(seg, z)