diff --git a/generative/losses/kld_loss.py b/generative/losses/kld_loss.py new file mode 100644 index 00000000..ebcaf52b --- /dev/null +++ b/generative/losses/kld_loss.py @@ -0,0 +1,20 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn + + +class KLDLoss(nn.Module): + def forward(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) diff --git a/generative/networks/blocks/spade_norm.py b/generative/networks/blocks/spade_norm.py new file mode 100644 index 00000000..0fe735e8 --- /dev/null +++ b/generative/networks/blocks/spade_norm.py @@ -0,0 +1,95 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.networks.blocks import ADN, Convolution + + +class SPADE(nn.Module): + """ + SPADE normalisation block based on the 2019 paper by Park et al. (doi: https://doi.org/10.48550/arXiv.1903.07291) + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict = {}, + ) -> None: + + super().__init__() + + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + padding=kernel_size // 2, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor + segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. 
generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py new file mode 100644 index 00000000..b2ff2833 --- /dev/null +++ b/generative/networks/nets/spade_network.py @@ -0,0 +1,418 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.utils.enums import StrEnum + +from generative.losses.kld_loss import KLDLoss +from generative.networks.blocks.spade_norm import SPADE + + +class UpsamplingModes(StrEnum): + bicubic = "bicubic" + nearest = "nearest" + bilinear = "bilinear" + + +class SPADE_ResNetBlock(nn.Module): + """ + Creates a Residual Block with SPADE normalisation. 
+ + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks + spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks + norm: base normalisation type used on top of SPADE + kernel_size: convolutional kernel size + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + kernel_size: int = 3, + ): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.int_channels = min(self.in_channels, self.out_channels) + self.learned_shortcut = self.in_channels != self.out_channels + self.conv_0 = Convolution( + spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None + ) + self.conv_1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.int_channels, + out_channels=self.out_channels, + act=None, + norm=None, + ) + self.activation = nn.LeakyReLU(0.2, False) + self.norm_0 = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + self.norm_1 = SPADE( + label_nc=label_nc, + norm_nc=self.int_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + if self.learned_shortcut: + self.conv_s = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + act=None, + norm=None, + kernel_size=1, + ) + self.norm_s = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + 
norm=norm, + ) + + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.activation(self.norm_0(x, seg))) + dx = self.conv_1(self.activation(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + +class SPADE_Encoder(nn.Module): + """ + Encoding branch of a VAE compatible with a SPADE-like generator + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + z_dim: latent space dimension of the VAE containing the image sytle information + num_channels: number of output after each downsampling block + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + of the autoencoder (HxWx[D]) + kernel_size: convolutional kernel size + norm: normalisation layer type + act: activation type + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + z_dim: int, + num_channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + ): + + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.num_channels = num_channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. 
" % (s_ind, s_, len(num_channels)) + ) + self.input_shape = input_shape + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in self.input_shape] + blocks = [] + ch_init = self.in_channels + for ch_ind, ch_value in enumerate(num_channels): + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=ch_init, + out_channels=ch_value, + strides=2, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + ) + ch_init = ch_value + + self.blocks = nn.ModuleList(blocks) + self.fc_mu = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], out_features=self.z_dim + ) + self.fc_var = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], out_features=self.z_dim + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return mu, logvar + + def encode(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return self.reparameterize(mu, logvar) + + def reparameterize(self, mu, logvar): + + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + + +class SPADE_Decoder(nn.Module): + """ + Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, + behaving like a GAN, or coupled to a SPADE encoder. 
+ + Args: + label_nc: number of semantic labels + spatial_dims: number of spatial dimensions + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + num_channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_gan: whether the decoder is going to be coupled to an autoencoder or not (true: not, false: yes) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_gan: bool = False, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + + super().__init__() + self.is_gan = is_gan + self.out_channels = out_channels + self.label_nc = label_nc + self.num_channels = num_channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." 
+ "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(num_channels)) + ) + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] + + if self.is_gan: + self.fc = nn.Linear(label_nc, np.prod(self.latent_spatial_shape) * num_channels[0]) + else: + self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * num_channels[0]) + + blocks = [] + num_channels.append(self.out_channels) + self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) + for ch_ind, ch_value in enumerate(num_channels[:-1]): + blocks.append( + SPADE_ResNetBlock( + spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=num_channels[ch_ind + 1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size, + ) + ) + + self.blocks = torch.nn.ModuleList(blocks) + self.last_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + padding=(kernel_size - 1) // 2, + kernel_size=kernel_size, + norm=None, + act=last_act, + ) + + def forward(self, seg, z: torch.Tensor = None): + if self.is_gan: + x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) + x = self.fc(x) + else: + if z is None: + z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=seg.get_device()) + x = self.fc(z) + x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) + + for res_block in self.blocks: + x = res_block(x, seg) + x = self.upsampling(x) + + x = self.last_conv(x) + return x + + +class SPADE_Net(nn.Module): + + """ + SPADE Network, implemented based on the code by Park, T et al. 
in "Semantic Image Synthesis with Spatially-Adaptive Normalization" + (https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + num_channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + + super().__init__() + self.is_vae = is_vae + if self.is_vae and z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + self.label_nc = label_nc + self.input_shape = input_shape + self.kld_loss = KLDLoss() + + if self.is_vae: + self.encoder = SPADE_Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + num_channels=num_channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + + decoder_channels = num_channels + decoder_channels.reverse() + + self.decoder = SPADE_Decoder( + spatial_dims=spatial_dims, + out_channels=out_channels, + label_nc=label_nc, + input_shape=input_shape, + num_channels=decoder_channels, + z_dim=z_dim, + is_gan=not is_vae, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + act=act, + last_act=last_act, + kernel_size=kernel_size, + upsampling_mode=upsampling_mode, + ) + + def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): + z = None + if self.is_vae: + z_mu, z_logvar = self.encoder(x) + z = self.encoder.reparameterize(z_mu, z_logvar) + kld_loss = self.kld_loss(z_mu, z_logvar) + return self.decoder(seg, z), kld_loss + else: + return (self.decoder(seg, z),) + + def encode(self, x: torch.Tensor): + + return self.encoder.encode(x) + 
+ def decode(self, seg: torch.Tensor, z: Union[torch.Tensor, None] = None): + + return self.decoder(seg, z) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py new file mode 100644 index 00000000..e030b81e --- /dev/null +++ b/tests/test_spade_vaegan.py @@ -0,0 +1,131 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from monai.networks import eval_mode +from parameterized import parameterized + +from generative.networks.nets.spade_network import SPADE_Net + +CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] + + +def create_semantic_data(shape: list, semantic_regions: int): + """ + To create semantic and image mock inputs for the network. 
+ Args: + shape: input shape + semantic_regions: number of semantic regions + Returns: + """ + out_label = torch.zeros(shape) + out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 + for i in range(1, semantic_regions): + shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] + if len(shape) == 2: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = (base_intensity + torch.randn(shape_square) * 0.1) + elif len(shape) == 3: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = (base_intensity + torch.randn(shape_square) * 0.1) + else: + ValueError("Supports only 2D and 3D tensors") + + # One hot encode label + out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) + for ch in range(semantic_regions): + out_label_[ch, ...] = out_label == ch + + return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(CASE_2D) + def test_forward_2d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. 
+ """ + net = SPADE_Net(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) + self.assertEqual(list(out.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_2D_BIS) + def test_encoder_decoder(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADE_Net(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out_z = net.encode(in_image) + self.assertEqual(list(out_z.shape), [1, 16]) + out_i = net.decode(in_label, out_z) + self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_3D) + def test_forward_3d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. 
+ """ + net = SPADE_Net(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) + self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) + + def test_shape_wrong(self): + """ + We input an input shape that isn't divisible by 2**(n downstream steps) + """ + with self.assertRaises(ValueError): + net = SPADE_Net(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb index 11d059c3..2b398ee5 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "350736c2", "metadata": {}, "outputs": [ @@ -80,29 +80,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "MONAI version: 1.1.dev2239\n", - "Numpy version: 1.23.3\n", - "Pytorch version: 1.8.0+cu111\n", + "MONAI version: 1.2.dev2304\n", + "Numpy version: 1.23.5\n", + "Pytorch version: 1.13.1+cu117\n", "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", - "MONAI rev id: 13b24fa92b9d98bd0dc6d5cdcb52504fd09e297b\n", - "MONAI __file__: /media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/monai/__init__.py\n", + "MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0\n", + "MONAI __file__: /home/vf19/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/monai/__init__.py\n", "\n", "Optional dependencies:\n", "Pytorch Ignite version: 0.4.10\n", - "Nibabel version: 4.0.2\n", - "scikit-image version: NOT INSTALLED or UNKNOWN 
VERSION.\n", - "Pillow version: 9.2.0\n", - "Tensorboard version: 2.11.0\n", - "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n", - "TorchVision version: 0.9.0+cu111\n", + "ITK version: 5.3.0\n", + "Nibabel version: 5.0.0\n", + "scikit-image version: 0.19.3\n", + "Pillow version: 9.4.0\n", + "Tensorboard version: 2.12.0\n", + "gdown version: 4.6.3\n", + "TorchVision version: 0.14.1+cu117\n", "tqdm version: 4.64.1\n", - "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n", - "psutil version: 5.9.3\n", - "pandas version: NOT INSTALLED or UNKNOWN VERSION.\n", + "lmdb version: 1.4.0\n", + "psutil version: 5.9.4\n", + "pandas version: 1.5.3\n", "einops version: 0.6.0\n", - "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n", - "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n", - "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.1.1\n", + "pynrrd version: 1.0.0\n", "\n", "For details about installing the optional dependencies, please visit:\n", " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", @@ -115,7 +116,7 @@ "import shutil\n", "import tempfile\n", "import time\n", - "\n", + "from pathlib import Path\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", @@ -1229,7 +1230,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py index 53ccf898..39e44730 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py @@ -42,7 +42,6 @@ import shutil import tempfile import time - import matplotlib.pyplot as plt import numpy as np import torch diff --git 
a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb new file mode 100644 index 00000000..125ecd8b --- /dev/null +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -0,0 +1,2405 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1f7ba8ce", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "102909fb", + "metadata": {}, + "source": [ + "# SPADE VAE-GAN" + ] + }, + { + "cell_type": "markdown", + "id": "7d4cbc3c", + "metadata": {}, + "source": [ + "This notebook creates a mock SPADE VAE-GAN based on the paper \"Semantic Image Synthesis with Spatially-Adaptive Normalization\" (2019) by Park T, Liu MY, Wang TC, Zhu JY. 
More information available at: https://github.com/NVlabs/SPADE" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e059c423", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from pathlib import Path\n", + "import zipfile\n", + "import gdown\n", + "from monai.data import DataLoader\n", + "from tqdm import tqdm\n", + "from generative.losses import PatchAdversarialLoss, PerceptualLoss\n", + "from generative.networks.nets import MultiScalePatchDiscriminator\n", + "import numpy as np\n", + "import monai\n", + "from generative.networks.nets.spade_network import SPADE_Net" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e76296e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temporary directory used: /tmp/tmpz0rbc_3s \n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "root_dir = Path(root_dir)\n", + "print(\"Temporary directory used: %s \" % root_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2483148a", + "metadata": {}, + "outputs": [], + "source": [ + "# INPUT PARAMETERS\n", + "input_shape = [128, 128]\n", + "batch_size = 6\n", + "num_workers = 4\n", + "num_epochs = 100\n", + "lambda_perc = 1.0\n", + "lambda_feat = 0.1\n", + "lambda_kld = 0.00001\n", + "loss_adv = 1.0" + ] + }, + { + "cell_type": "markdown", + "id": "3f2448ba", + "metadata": {}, + "source": [ + "### Data" + ] + }, + { + "cell_type": "markdown", + "id": "2c91eec2", + "metadata": {}, + "source": [ + "The data for this notebook comes from the public dataset OASIS (Open Access Series of Imaging Studies) [1]. The images have been registered to MNI space using ANTsPy, and then subsampled to 2mm isotropic resolution. 
Geodesic Information Flows (GIF) [2] has been used to segment 5 regions: cerebrospinal fluid (CSF), grey matter (GM), white matter (WM), deep grey matter (DGM) and brainstem. In addition, BaMos [3] has been used to provide white matter hyperintensities segmentations (WMH). The available dataset contains:\n", + "- T1-weighted images\n", + "- FLAIR-weighted images\n", + "- Segmentations with the following labels: 0 (background), 1 (CSF), 2 (GM), 3 (WM), 4 (DGM), 5 (brainstem) and 6 (WMH).\n", + "\n", + "_**Acknowledgments**: \"Data were provided by OASIS-3: Longitudinal Multimodal Neuroimaging: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P30 AG066444, P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly owned subsidiary of Eli Lilly.\"_\n", + "\n", + "\n", + "Citations:\n", + "\n", + "[1] Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner, RL. Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults. Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498\n", + "\n", + "[2] Cardoso MJ, Modat M, Wolz R, Melbourne A, Cash D, Rueckert D, Ourselin S. Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion. IEEE Trans Med Imaging. 2015 Sep;34(9):1976-88. doi: 10.1109/TMI.2015.2418298. Epub 2015 Apr 14. PMID: 25879909.\n", + "\n", + "[3] Fiford CM, Sudre CH, Pemberton H, Walsh P, Manning E, Malone IB, Nicholas J, Bouvy WH, Carmichael OT, Biessels GJ, Cardoso MJ, Barnes J; Alzheimer’s Disease Neuroimaging Initiative. Automated White Matter Hyperintensity Segmentation Using Bayesian Model Selection: Assessment and Correlations with Cognitive Change. Neuroinformatics. 2020 Jun;18(3):429-449. doi: 10.1007/s12021-019-09439-6. 
# --- Download and unpack the dataset -----------------------------------------
# Fetch the zipped dataset from Google Drive into the working directory.
gdown.download(
    "https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5", str(root_dir / "data.zip")
)

# Extract inside a context manager so the archive handle is always closed,
# even when extraction raises.
with zipfile.ZipFile(os.path.join(root_dir, "data.zip"), "r") as zip_obj:
    zip_obj.extractall(root_dir)
images_T1 = root_dir / "OASIS_SMALL-SUBSET/T1"
images_FLAIR = root_dir / "OASIS_SMALL-SUBSET/FLAIR"
labels = root_dir / "OASIS_SMALL-SUBSET/Segmentations"

# --- Data dictionaries --------------------------------------------------------
# We create the data dictionaries that we need: T1 and FLAIR images are pooled
# and shuffled together.
all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + [
    os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)
]
np.random.shuffle(all_images)
# The label file name is the image file name with its first "_"-separated token
# replaced by "Parcellation" (str.replace substitutes every occurrence of that
# token). os.path.basename is used instead of splitting on "/" so the lookup
# does not depend on the platform's path separator.
corresponding_labels = [
    os.path.join(labels, os.path.basename(i).replace(os.path.basename(i).split("_")[0], "Parcellation"))
    for i in all_images
]
input_dict = [{"image": image, "label": label} for image, label in zip(all_images, corresponding_labels)]
# First 90 % of the shuffled list is used for training, the remainder for validation.
split_index = int(len(input_dict) * 0.9)
input_dict_train = input_dict[:split_index]
input_dict_val = input_dict[split_index:]

# --- Dataloaders --------------------------------------------------------------
preliminar_shape = input_shape + [50]  # We take random slices from the center of the brain
crop_shape = input_shape + [1]
# Load -> channel first -> central slab of 50 slices -> one random slice -> drop
# the singleton slice axis -> resize to the network input shape.
base_transforms = [
    monai.transforms.LoadImaged(keys=["label", "image"]),
    monai.transforms.EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.CenterSpatialCropd(keys=["label", "image"], roi_size=preliminar_shape),
    monai.transforms.RandSpatialCropd(keys=["label", "image"], roi_size=crop_shape, max_roi_size=crop_shape),
    monai.transforms.SqueezeDimd(keys=["label", "image"], dim=-1),
    monai.transforms.Resized(keys=["image", "label"], spatial_size=input_shape),
]
# A binary foreground mask (label != 0) is derived from the label and used to
# mask the image before intensity normalisation.
last_transforms = [
    monai.transforms.CopyItemsd(keys=["label"], names=["label_channel"]),
    monai.transforms.Lambdad(keys=["label_channel"], func=lambda lbl: lbl != 0),
    monai.transforms.MaskIntensityd(keys=["image"], mask_key="label_channel"),
    monai.transforms.NormalizeIntensityd(keys=["image"]),
    monai.transforms.ToTensord(keys=["image", "label"]),
]

# NOTE: np.random.uniform(...) below is evaluated once, when this list is built,
# so a single std value is passed to RandGaussianNoised for the whole run.
aug_transforms = [
    monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=["image"]),
    monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=["image"]),
    monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), keys=["image"]),
    monai.transforms.RandAffined(
        rotate_range=[-0.05, 0.05],
        shear_range=[0.001, 0.05],
        scale_range=[0, 0.05],
        padding_mode="zeros",
        mode="nearest",
        prob=0.33,
        keys=["label", "image"],
    ),
]

# Augmentations are applied to the training set only.
train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms)
val_transforms = monai.transforms.Compose(base_transforms + last_transforms)

train_dataset = monai.data.dataset.Dataset(input_dict_train, train_transforms)
val_dataset = monai.data.dataset.Dataset(input_dict_val, val_transforms)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, batch_size=batch_size, num_workers=num_workers)
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Sanity check\n", + "batch = next(iter(train_loader))\n", + "print(batch[\"image\"].shape)\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(batch[\"image\"][0, 0, ...], cmap=\"gist_gray\")\n", + "plt.axis(\"off\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(batch[\"label\"][0, 0, ...], cmap=\"jet\")\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "63de4490", + "metadata": {}, + "source": [ + "### Network creation and losses" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fa17d864", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8f126b17", + "metadata": {}, + "outputs": [], + "source": [ + "def one_hot(input_label, label_nc):\n", + " # One hot encoding function for the labels\n", + " shape_ = list(input_label.shape)\n", + " shape_[1] = label_nc\n", + " label_out = torch.zeros(shape_)\n", + " for channel in range(label_nc):\n", + " label_out[:, channel, ...] = input_label[:, 0, ...] 
== channel\n", + " return label_out" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "255c90c7", + "metadata": {}, + "outputs": [], + "source": [ + "def picture_results(input_label, input_image, output_image):\n", + " f = plt.figure(figsize=(4, 1.5))\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(torch.argmax(input_label, 1)[0, ...].detach().cpu(), cmap=\"jet\")\n", + " plt.axis(\"off\")\n", + " plt.title(\"Label\")\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(input_image[0, 0, ...].detach().cpu(), cmap=\"gist_gray\")\n", + " plt.axis(\"off\")\n", + " plt.title(\"Input image\")\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(output_image[0, 0, ...].detach().cpu(), cmap=\"gist_gray\")\n", + " plt.axis(\"off\")\n", + " plt.title(\"Output image\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c18dbad8", + "metadata": {}, + "outputs": [], + "source": [ + "def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat, device):\n", + " criterion = torch.nn.L1Loss()\n", + " num_D = len(input_features_disc_fake)\n", + " GAN_Feat_loss = torch.zeros(1).to(device)\n", + " for i in range(num_D): # for each discriminator\n", + " num_intermediate_outputs = len(input_features_disc_fake[i])\n", + " for j in range(num_intermediate_outputs): # for each layer output\n", + " unweighted_loss = criterion(input_features_disc_fake[i][j], input_features_disc_real[i][j].detach())\n", + " GAN_Feat_loss += unweighted_loss * lambda_feat / num_D\n", + " return GAN_Feat_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "89989c34", + "metadata": {}, + "outputs": [], + "source": [ + "net = SPADE_Net(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " label_nc=6,\n", + " input_shape=input_shape,\n", + " num_channels=[16, 32, 64, 128],\n", + " z_dim=16,\n", + " is_vae=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": 
"5b8b676f", + "metadata": {}, + "outputs": [], + "source": [ + "discriminator = MultiScalePatchDiscriminator(\n", + " num_d=2,\n", + " num_layers_d=3,\n", + " spatial_dims=2,\n", + " num_channels=8,\n", + " in_channels=7,\n", + " out_channels=7,\n", + " minimum_size_im=128,\n", + " norm=\"INSTANCE\",\n", + " kernel_size=3,\n", + ")\n", + "\n", + "adversarial_loss = PatchAdversarialLoss(reduction=\"sum\", criterion=\"hinge\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "36ea4308", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n" + ] + } + ], + "source": [ + "perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"vgg\", is_fake_3d=False, pretrained=True)\n", + "perceptual_loss = perceptual_loss.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3b57abd3", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer_G = torch.optim.Adam(net.parameters(), lr=0.0002)\n", + "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0004)" + ] + }, + { + "cell_type": "markdown", + "id": "b8fde71b", + "metadata": {}, + "source": [ + "### Training loop\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "918eac0a", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=135, perceptual=0.221, generator=3.24, feature=0.113, 
discriminator=2.27]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.22it/s, kld=63.6, perceptual=0.254, generator=3.38, feature=0.125, discriminator=1.77]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.26it/s, kld=754, perceptual=0.354, generator=3.46, feature=0.168, discriminator=1.79]\n", + "100%|████████| 2/2 [00:00<00:00, 4.36it/s, kld=359, perceptual=0.36, generator=3.44, feature=0.173, discriminator=1.85]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=298, perceptual=0.324, generator=3.57, feature=0.167, discriminator=1.58]\n", + "100%|████████| 2/2 [00:00<00:00, 4.07it/s, kld=346, perceptual=0.34, generator=3.57, feature=0.167, discriminator=1.63]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███████| 15/15 [00:06<00:00, 2.27it/s, kld=313, perceptual=0.32, generator=3.57, feature=0.166, discriminator=1.6]\n", + "100%|███████| 2/2 [00:00<00:00, 4.36it/s, kld=262, perceptual=0.325, generator=3.49, feature=0.164, discriminator=1.63]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.28it/s, kld=213, perceptual=0.305, generator=3.52, feature=0.168, discriminator=1.67]\n", + "100%|███████| 2/2 [00:00<00:00, 4.45it/s, kld=101, perceptual=0.307, generator=3.58, feature=0.164, discriminator=1.65]" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=119, perceptual=0.302, generator=3.6, feature=0.165, discriminator=1.51]\n", + "100%|██████| 2/2 [00:00<00:00, 3.79it/s, kld=68.9, perceptual=0.288, generator=3.58, feature=0.165, discriminator=1.61]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=114, perceptual=0.286, generator=3.59, feature=0.163, discriminator=1.5]\n", + "100%|████████| 2/2 [00:00<00:00, 4.25it/s, kld=80, perceptual=0.278, generator=3.59, feature=0.161, discriminator=1.51]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████████| 15/15 [00:06<00:00, 2.28it/s, kld=116, perceptual=0.263, generator=3.6, feature=0.16, discriminator=1.5]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.44it/s, kld=73.9, perceptual=0.269, generator=3.6, feature=0.16, discriminator=1.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███████| 15/15 [00:06<00:00, 2.22it/s, kld=100, perceptual=0.25, generator=3.6, feature=0.155, discriminator=1.53]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.31it/s, kld=54.1, perceptual=0.26, generator=3.6, feature=0.154, discriminator=1.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.19it/s, kld=97.2, perceptual=0.243, generator=3.54, feature=0.149, discriminator=1.57]\n", + "100%|██████| 2/2 [00:00<00:00, 
4.01it/s, kld=69.7, perceptual=0.252, generator=3.62, feature=0.141, discriminator=1.46]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=106, perceptual=0.242, generator=3.52, feature=0.129, discriminator=1.72]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.42it/s, kld=109, perceptual=0.258, generator=3.54, feature=0.128, discriminator=1.6]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=413, perceptual=0.325, generator=3.21, feature=0.144, discriminator=2.18]\n", + "100%|███████| 2/2 [00:00<00:00, 3.80it/s, kld=252, perceptual=0.325, generator=3.39, feature=0.141, discriminator=1.88]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.26it/s, kld=238, perceptual=0.311, generator=3.55, feature=0.169, discriminator=1.67]\n", + "100%|████████| 2/2 [00:00<00:00, 4.29it/s, kld=214, perceptual=0.317, generator=3.57, feature=0.17, discriminator=1.61]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=110, perceptual=0.304, generator=3.57, feature=0.177, discriminator=1.55]\n", + "100%|███████| 2/2 [00:00<00:00, 4.01it/s, kld=53.1, perceptual=0.314, generator=3.58, feature=0.176, discriminator=1.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ 
+ "Epoch 14/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████████| 15/15 [00:06<00:00, 2.16it/s, kld=129, perceptual=0.27, generator=3.6, feature=0.17, discriminator=1.52]\n", + "100%|████████| 2/2 [00:00<00:00, 4.45it/s, kld=83.5, perceptual=0.295, generator=3.6, feature=0.17, discriminator=1.47]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=74.5, perceptual=0.254, generator=3.59, feature=0.166, discriminator=1.52]\n", + "100%|███████| 2/2 [00:00<00:00, 4.49it/s, kld=38.9, perceptual=0.285, generator=3.6, feature=0.166, discriminator=1.47]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.29it/s, kld=103, perceptual=0.25, generator=3.63, feature=0.162, discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.36it/s, kld=80.4, perceptual=0.246, generator=3.6, feature=0.159, discriminator=1.47]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.28it/s, kld=95.8, perceptual=0.247, generator=3.62, feature=0.162, discriminator=1.44]\n", + "100%|████████| 2/2 [00:00<00:00, 4.44it/s, kld=60, perceptual=0.233, generator=3.51, feature=0.166, discriminator=1.59]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=476, perceptual=0.248, generator=3.36, feature=0.152, discriminator=1.84]\n", + "100%|███████| 2/2 [00:00<00:00, 2.86it/s, kld=181, perceptual=0.251, 
generator=3.59, feature=0.172, discriminator=1.53]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=478, perceptual=0.285, generator=3.52, feature=0.15, discriminator=1.87]\n", + "100%|██████| 2/2 [00:00<00:00, 4.42it/s, kld=483, perceptual=0.304, generator=2.73, feature=0.0971, discriminator=2.46]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=725, perceptual=0.348, generator=3.31, feature=0.157, discriminator=1.78]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 2.95it/s, kld=530, perceptual=0.346, generator=3.16, feature=0.143, discriminator=1.94]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=143, perceptual=0.382, generator=3.6, feature=0.218, discriminator=1.46]\n", + "100%|██████| 2/2 [00:00<00:00, 4.35it/s, kld=80.1, perceptual=0.399, generator=3.61, feature=0.225, discriminator=1.44]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=179, perceptual=0.344, generator=3.59, feature=0.209, discriminator=1.51]\n", + "100%|██████| 2/2 [00:00<00:00, 4.43it/s, kld=91.7, perceptual=0.366, generator=3.62, feature=0.217, discriminator=1.38]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23/100\n" + ] + }, + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=336, perceptual=0.31, generator=3.59, feature=0.182, discriminator=1.49]\n", + "100%|████████| 2/2 [00:00<00:00, 4.46it/s, kld=216, perceptual=0.311, generator=3.6, feature=0.185, discriminator=1.48]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=423, perceptual=0.271, generator=3.58, feature=0.157, discriminator=1.53]\n", + "100%|███████| 2/2 [00:00<00:00, 4.07it/s, kld=136, perceptual=0.269, generator=3.59, feature=0.168, discriminator=1.52]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.26it/s, kld=349, perceptual=0.276, generator=3.6, feature=0.155, discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.41it/s, kld=220, perceptual=0.272, generator=3.61, feature=0.156, discriminator=1.45]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.31it/s, kld=109, perceptual=0.265, generator=3.62, feature=0.156, discriminator=1.46]\n", + "100%|███████| 2/2 [00:00<00:00, 4.00it/s, kld=117, perceptual=0.246, generator=3.59, feature=0.157, discriminator=1.52]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=209, perceptual=0.235, generator=3.61, feature=0.154, discriminator=1.47]\n", + "100%|███████| 2/2 [00:00<00:00, 4.47it/s, kld=120, perceptual=0.256, generator=3.61, feature=0.156, 
discriminator=1.43]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.33it/s, kld=146, perceptual=0.231, generator=3.61, feature=0.15, discriminator=1.44]\n", + "100%|██████| 2/2 [00:00<00:00, 4.35it/s, kld=76.3, perceptual=0.248, generator=3.61, feature=0.153, discriminator=1.41]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.28it/s, kld=180, perceptual=0.225, generator=3.59, feature=0.147, discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.18it/s, kld=98.1, perceptual=0.228, generator=3.6, feature=0.149, discriminator=1.43]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.31it/s, kld=113, perceptual=0.233, generator=3.63, feature=0.145, discriminator=1.37]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████| 2/2 [00:00<00:00, 3.31it/s, kld=90, perceptual=0.226, generator=3.6, feature=0.142, discriminator=1.43]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=114, perceptual=0.227, generator=3.63, feature=0.145, discriminator=1.4]\n", + "100%|██████| 2/2 [00:00<00:00, 4.38it/s, kld=62.7, perceptual=0.212, generator=3.64, feature=0.144, discriminator=1.32]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32/100\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=102, perceptual=0.229, generator=3.62, feature=0.149, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.42it/s, kld=89.2, perceptual=0.233, generator=3.63, feature=0.143, discriminator=1.36]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███████| 15/15 [00:06<00:00, 2.23it/s, kld=155, perceptual=0.2, generator=3.61, feature=0.144, discriminator=1.38]\n", + "100%|██████| 2/2 [00:00<00:00, 4.19it/s, kld=70.7, perceptual=0.226, generator=3.61, feature=0.144, discriminator=1.41]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=118, perceptual=0.217, generator=3.6, feature=0.144, discriminator=1.38]\n", + "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=77.5, perceptual=0.204, generator=3.62, feature=0.145, discriminator=1.37]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=128, perceptual=0.199, generator=3.49, feature=0.138, discriminator=1.54]\n", + "100%|██████| 2/2 [00:00<00:00, 4.23it/s, kld=72.3, perceptual=0.219, generator=3.58, feature=0.146, discriminator=1.42]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=137, perceptual=0.206, generator=3.62, feature=0.144, discriminator=1.37]\n", + "100%|██████| 2/2 [00:00<00:00, 4.47it/s, kld=98.9, perceptual=0.222, generator=3.63, feature=0.145, discriminator=1.42]" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "Epoch 37/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=250, perceptual=0.227, generator=3.61, feature=0.148, discriminator=1.47]\n", + "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=142, perceptual=0.217, generator=3.58, feature=0.148, discriminator=1.45]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.32it/s, kld=101, perceptual=0.208, generator=3.62, feature=0.145, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.44it/s, kld=66.2, perceptual=0.211, generator=3.63, feature=0.145, discriminator=1.35]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=99.7, perceptual=0.209, generator=3.61, feature=0.146, discriminator=1.37]\n", + "100%|██████| 2/2 [00:00<00:00, 3.83it/s, kld=74.2, perceptual=0.202, generator=3.62, feature=0.143, discriminator=1.38]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.25it/s, kld=132, perceptual=0.206, generator=3.6, feature=0.141, discriminator=1.37]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.33it/s, kld=86.4, perceptual=0.205, generator=3.63, feature=0.142, discriminator=1.36]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 
15/15 [00:06<00:00, 2.22it/s, kld=87.2, perceptual=0.208, generator=3.64, feature=0.138, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=73.5, perceptual=0.197, generator=3.62, feature=0.141, discriminator=1.35]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=120, perceptual=0.219, generator=3.62, feature=0.148, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 4.48it/s, kld=55.8, perceptual=0.207, generator=3.63, feature=0.136, discriminator=1.33]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=95.9, perceptual=0.198, generator=3.64, feature=0.135, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.34it/s, kld=57.2, perceptual=0.209, generator=3.66, feature=0.135, discriminator=1.25]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=128, perceptual=0.197, generator=3.62, feature=0.132, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=82.2, perceptual=0.182, generator=3.62, feature=0.124, discriminator=1.35]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=108, perceptual=0.184, generator=3.63, feature=0.126, discriminator=1.32]\n", + "100%|██████| 2/2 [00:00<00:00, 4.10it/s, kld=72.3, perceptual=0.195, generator=3.64, feature=0.132, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Epoch 46/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=115, perceptual=0.188, generator=3.62, feature=0.126, discriminator=1.32]\n", + "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=60.3, perceptual=0.213, generator=3.65, feature=0.129, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=105, perceptual=0.206, generator=3.64, feature=0.127, discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.34it/s, kld=59.4, perceptual=0.202, generator=3.66, feature=0.125, discriminator=1.26]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=106, perceptual=0.176, generator=3.63, feature=0.128, discriminator=1.32]\n", + "100%|███████| 2/2 [00:00<00:00, 4.31it/s, kld=55.1, perceptual=0.21, generator=3.64, feature=0.134, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=88.5, perceptual=0.192, generator=3.62, feature=0.131, discriminator=1.33]\n", + "100%|████████| 2/2 [00:00<00:00, 4.29it/s, kld=63.5, perceptual=0.2, generator=3.61, feature=0.131, discriminator=1.35]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 50/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.33it/s, kld=94.5, perceptual=0.2, generator=3.64, feature=0.129, discriminator=1.31]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + 
{ + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.45it/s, kld=52.4, perceptual=0.203, generator=3.65, feature=0.133, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 51/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.27it/s, kld=87.7, perceptual=0.192, generator=3.62, feature=0.131, discriminator=1.34]\n", + "100%|███████| 2/2 [00:00<00:00, 4.27it/s, kld=50.3, perceptual=0.207, generator=3.64, feature=0.133, discriminator=1.3]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 52/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.30it/s, kld=91.9, perceptual=0.184, generator=3.62, feature=0.127, discriminator=1.33]\n", + "100%|██████| 2/2 [00:00<00:00, 4.32it/s, kld=65.7, perceptual=0.182, generator=3.62, feature=0.128, discriminator=1.35]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 53/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=84.6, perceptual=0.193, generator=3.62, feature=0.136, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 4.45it/s, kld=67.2, perceptual=0.176, generator=3.63, feature=0.129, discriminator=1.32]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 54/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████████| 15/15 [00:06<00:00, 2.31it/s, kld=91.9, perceptual=0.2, generator=3.64, feature=0.14, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.49it/s, kld=43.8, perceptual=0.204, generator=3.64, feature=0.133, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 55/100\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=118, perceptual=0.183, generator=3.64, feature=0.129, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.44it/s, kld=71.7, perceptual=0.195, generator=3.6, feature=0.141, discriminator=1.37]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 56/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.35it/s, kld=102, perceptual=0.202, generator=3.63, feature=0.137, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.18it/s, kld=64.4, perceptual=0.188, generator=3.62, feature=0.133, discriminator=1.33]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 57/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.28it/s, kld=96.8, perceptual=0.181, generator=3.61, feature=0.125, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.54it/s, kld=54.1, perceptual=0.201, generator=3.59, feature=0.136, discriminator=1.36]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 58/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=137, perceptual=0.195, generator=3.62, feature=0.134, discriminator=1.35]\n", + "100%|██████| 2/2 [00:00<00:00, 4.09it/s, kld=64.9, perceptual=0.198, generator=3.63, feature=0.133, discriminator=1.32]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 59/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=77.3, perceptual=0.189, generator=3.64, feature=0.13, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 3.47it/s, kld=48.4, perceptual=0.187, generator=3.65, feature=0.132, discriminator=1.26]" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 60/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:07<00:00, 2.11it/s, kld=76.1, perceptual=0.179, generator=3.62, feature=0.132, discriminator=1.31]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.59it/s, kld=45.9, perceptual=0.196, generator=3.65, feature=0.133, discriminator=1.26]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 61/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=91.5, perceptual=0.192, generator=3.64, feature=0.134, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.30it/s, kld=51.4, perceptual=0.175, generator=3.63, feature=0.127, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 62/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.24it/s, kld=89.7, perceptual=0.199, generator=3.65, feature=0.132, discriminator=1.25]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=51.4, perceptual=0.179, generator=3.59, feature=0.133, discriminator=1.36]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 63/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=85.4, perceptual=0.18, generator=3.61, feature=0.13, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 3.02it/s, kld=69.2, perceptual=0.195, generator=3.62, feature=0.135, discriminator=1.33]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 64/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + 
"100%|██████| 15/15 [00:06<00:00, 2.30it/s, kld=113, perceptual=0.185, generator=3.64, feature=0.126, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.46it/s, kld=67.6, perceptual=0.206, generator=3.65, feature=0.128, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 65/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.27it/s, kld=79, perceptual=0.163, generator=3.61, feature=0.124, discriminator=1.33]\n", + "100%|██████| 2/2 [00:00<00:00, 3.91it/s, kld=47.8, perceptual=0.198, generator=3.63, feature=0.133, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 66/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:07<00:00, 2.13it/s, kld=80.8, perceptual=0.174, generator=3.63, feature=0.13, discriminator=1.3]\n", + "100%|███████| 2/2 [00:00<00:00, 4.20it/s, kld=62.5, perceptual=0.172, generator=3.63, feature=0.134, discriminator=1.3]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 67/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=88.8, perceptual=0.208, generator=3.64, feature=0.135, discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=48.4, perceptual=0.196, generator=3.64, feature=0.135, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 68/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=100, perceptual=0.202, generator=3.63, feature=0.135, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.05it/s, kld=61.3, perceptual=0.208, generator=3.65, feature=0.135, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "Epoch 69/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=116, perceptual=0.201, generator=3.64, feature=0.134, discriminator=1.29]\n", + "100%|████████| 2/2 [00:00<00:00, 4.46it/s, kld=68, perceptual=0.188, generator=3.64, feature=0.132, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 70/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=74.7, perceptual=0.187, generator=3.63, feature=0.131, discriminator=1.3]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.44it/s, kld=54.3, perceptual=0.17, generator=3.64, feature=0.126, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 71/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=102, perceptual=0.186, generator=3.65, feature=0.132, discriminator=1.27]\n", + "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=63.4, perceptual=0.191, generator=3.66, feature=0.131, discriminator=1.25]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 72/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=161, perceptual=0.191, generator=3.64, feature=0.132, discriminator=1.29]\n", + "100%|███████| 2/2 [00:00<00:00, 4.33it/s, kld=77.6, perceptual=0.165, generator=3.64, feature=0.13, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 73/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=102, 
perceptual=0.185, generator=3.61, feature=0.132, discriminator=1.31]\n", + "100%|██████| 2/2 [00:00<00:00, 4.36it/s, kld=64.2, perceptual=0.204, generator=3.64, feature=0.137, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 74/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=71.9, perceptual=0.188, generator=3.63, feature=0.135, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.13it/s, kld=50.9, perceptual=0.206, generator=3.65, feature=0.131, discriminator=1.26]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 75/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=72.6, perceptual=0.179, generator=3.65, feature=0.124, discriminator=1.26]\n", + "100%|███████| 2/2 [00:00<00:00, 4.15it/s, kld=44.3, perceptual=0.182, generator=3.58, feature=0.13, discriminator=1.34]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 76/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.20it/s, kld=165, perceptual=0.177, generator=3.64, feature=0.131, discriminator=1.3]\n", + "100%|████████| 2/2 [00:00<00:00, 4.08it/s, kld=95, perceptual=0.205, generator=3.63, feature=0.138, discriminator=1.31]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 77/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=107, perceptual=0.179, generator=3.64, feature=0.126, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.41it/s, kld=56.1, perceptual=0.203, generator=3.63, feature=0.138, discriminator=1.3]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 78/100\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=82.6, perceptual=0.174, generator=3.63, feature=0.126, discriminator=1.29]\n", + "100%|███████| 2/2 [00:00<00:00, 4.19it/s, kld=56.4, perceptual=0.19, generator=3.64, feature=0.136, discriminator=1.26]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 79/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=88.3, perceptual=0.177, generator=3.65, feature=0.123, discriminator=1.26]\n", + "100%|███████| 2/2 [00:00<00:00, 4.33it/s, kld=55.7, perceptual=0.17, generator=3.65, feature=0.128, discriminator=1.24]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 80/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=92.4, perceptual=0.185, generator=3.62, feature=0.135, discriminator=1.31]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.51it/s, kld=62.6, perceptual=0.191, generator=3.66, feature=0.132, discriminator=1.24]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 81/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.27it/s, kld=93.2, perceptual=0.183, generator=3.64, feature=0.132, discriminator=1.29]\n", + "100%|███████| 2/2 [00:00<00:00, 4.39it/s, kld=64.2, perceptual=0.197, generator=3.66, feature=0.13, discriminator=1.24]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 82/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=90.5, perceptual=0.176, generator=3.65, feature=0.131, 
discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.01it/s, kld=61.5, perceptual=0.199, generator=3.67, feature=0.131, discriminator=1.23]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 83/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=89.7, perceptual=0.184, generator=3.66, feature=0.129, discriminator=1.24]\n", + "100%|██████| 2/2 [00:00<00:00, 4.12it/s, kld=61.7, perceptual=0.181, generator=3.66, feature=0.123, discriminator=1.25]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 84/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=89.6, perceptual=0.181, generator=3.65, feature=0.133, discriminator=1.25]\n", + "100%|███████| 2/2 [00:00<00:00, 4.47it/s, kld=46.8, perceptual=0.178, generator=3.65, feature=0.13, discriminator=1.26]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 85/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=96.5, perceptual=0.192, generator=3.64, feature=0.13, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.30it/s, kld=66.1, perceptual=0.19, generator=3.63, feature=0.133, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 86/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=150, perceptual=0.197, generator=3.65, feature=0.134, discriminator=1.27]\n", + "100%|████████| 2/2 [00:00<00:00, 4.41it/s, kld=72, perceptual=0.208, generator=3.62, feature=0.141, discriminator=1.32]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 87/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "\n", + "100%|██████| 15/15 [00:06<00:00, 2.26it/s, kld=119, perceptual=0.178, generator=3.63, feature=0.129, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.19it/s, kld=81.4, perceptual=0.176, generator=3.63, feature=0.133, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 88/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=102, perceptual=0.173, generator=3.64, feature=0.13, discriminator=1.26]\n", + "100%|███████| 2/2 [00:00<00:00, 4.28it/s, kld=54.9, perceptual=0.192, generator=3.63, feature=0.136, discriminator=1.3]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 89/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=105, perceptual=0.178, generator=3.63, feature=0.137, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.18it/s, kld=85.9, perceptual=0.173, generator=3.64, feature=0.129, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 90/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.16it/s, kld=78.7, perceptual=0.17, generator=3.65, feature=0.132, discriminator=1.28]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.52it/s, kld=50.1, perceptual=0.159, generator=3.65, feature=0.126, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 91/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=120, perceptual=0.175, generator=3.65, feature=0.13, discriminator=1.26]\n", + "100%|██████| 2/2 
[00:00<00:00, 3.88it/s, kld=54.5, perceptual=0.193, generator=3.64, feature=0.136, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 92/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=85.1, perceptual=0.178, generator=3.67, feature=0.133, discriminator=1.22]\n", + "100%|██████| 2/2 [00:00<00:00, 4.22it/s, kld=56.2, perceptual=0.163, generator=3.66, feature=0.129, discriminator=1.24]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 93/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=98, perceptual=0.177, generator=3.64, feature=0.136, discriminator=1.27]\n", + "100%|██████| 2/2 [00:00<00:00, 4.38it/s, kld=63.8, perceptual=0.192, generator=3.66, feature=0.131, discriminator=1.24]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 94/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=117, perceptual=0.189, generator=3.64, feature=0.134, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=62.4, perceptual=0.158, generator=3.63, feature=0.124, discriminator=1.29]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 95/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=115, perceptual=0.186, generator=3.67, feature=0.131, discriminator=1.22]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=72.7, perceptual=0.161, generator=3.64, feature=0.126, discriminator=1.28]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 96/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 
2.16it/s, kld=83.2, perceptual=0.161, generator=3.66, feature=0.127, discriminator=1.24]\n", + "100%|████████| 2/2 [00:00<00:00, 4.27it/s, kld=76, perceptual=0.166, generator=3.65, feature=0.128, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 97/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=100, perceptual=0.173, generator=3.64, feature=0.128, discriminator=1.27]\n", + "100%|███████| 2/2 [00:00<00:00, 3.91it/s, kld=52.8, perceptual=0.16, generator=3.64, feature=0.122, discriminator=1.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 98/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=105, perceptual=0.188, generator=3.63, feature=0.133, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.30it/s, kld=66.9, perceptual=0.182, generator=3.66, feature=0.131, discriminator=1.22]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 99/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=107, perceptual=0.164, generator=3.64, feature=0.126, discriminator=1.3]\n", + "100%|███████| 2/2 [00:00<00:00, 3.99it/s, kld=63.3, perceptual=0.16, generator=3.66, feature=0.129, discriminator=1.25]\n" + ] + } + ], + "source": [ + "net = net.to(device)\n", + "discriminator = discriminator.to(device)\n", + "torch.autograd.set_detect_anomaly(True)\n", + "losses = {\"kld\": [], \"perceptual\": [], \"feature\": [], \"generator\": [], \"discriminator\": []}\n", + "losses_val = {\"kld\": [], \"perceptual\": [], \"feature\": [], \"generator\": [], \"discriminator\": []}\n", + "for epoch in range(num_epochs):\n", + " print(\"Epoch %d/%d\" % (epoch, num_epochs))\n", + " train_bar = tqdm(enumerate(train_loader), 
total=len(train_loader), ncols=120)\n", + " losses_epoch = {\"kld\": 0, \"perceptual\": 0, \"feature\": 0, \"generator\": 0, \"discriminator\": 0}\n", + " for step, d in train_bar:\n", + " image = d[\"image\"].to(device)\n", + " with torch.no_grad():\n", + " label = one_hot(d[\"label\"], 6).to(device)\n", + " optimizer_G.zero_grad()\n", + "\n", + " # Losses gen\n", + " out, kld_loss = net(label, image)\n", + " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)\n", + " disc_reals, features_reals = discriminator(torch.cat([image, label], 1))\n", + " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", + " loss_perc = perceptual_loss(out, target=image)\n", + " total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " total_loss.backward(retain_graph=True)\n", + " optimizer_G.step()\n", + "\n", + " # Store\n", + " losses_epoch[\"kld\"] += kld_loss.item()\n", + " losses_epoch[\"perceptual\"] += loss_perc.item()\n", + " losses_epoch[\"generator\"] += loss_g.item()\n", + " # Train disc\n", + " out, _ = net(label, image)\n", + " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True)\n", + " optimizer_D.zero_grad()\n", + " loss_d = loss_d_r + loss_g_f\n", + " loss_d.backward()\n", + " optimizer_D.step()\n", + "\n", + " # Store\n", + " losses_epoch[\"feature\"] = loss_feat.item()\n", + " losses_epoch[\"discriminator\"] = loss_d_r.item() + loss_g_f.item()\n", + "\n", + " train_bar.set_postfix(\n", + " {\n", + " \"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": loss_d_r.item() + 
loss_g_f.item(),\n", + " }\n", + " )\n", + "\n", + " val_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=120)\n", + " losses_epoch_val = {\"kld\": 0, \"perceptual\": 0, \"feature\": 0, \"generator\": 0, \"discriminator\": 0}\n", + " for step, d in val_bar:\n", + " image = d[\"image\"].to(device)\n", + " with torch.no_grad():\n", + " label = one_hot(d[\"label\"], 6).to(device)\n", + " # Losses gen\n", + " out, kld_loss = net(label, image)\n", + " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)\n", + " disc_reals, features_reals = discriminator(torch.cat([image, label], 1))\n", + " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", + " loss_perc = perceptual_loss(out, target=image)\n", + " total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " # Store\n", + " losses_epoch_val[\"kld\"] += kld_loss.item()\n", + " losses_epoch_val[\"perceptual\"] += loss_perc.item()\n", + " losses_epoch_val[\"generator\"] += loss_g.item()\n", + " # Train disc\n", + " out, _ = net(label, image)\n", + " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True)\n", + " loss_d = loss_adv * (loss_d_r + loss_g_f)\n", + "\n", + " # Store\n", + " losses_epoch_val[\"feature\"] = loss_feat.item()\n", + " losses_epoch_val[\"discriminator\"] = loss_d_r.item() + loss_g_f.item()\n", + "\n", + " val_bar.set_postfix(\n", + " {\n", + " \"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", + " }\n", + " )\n", + " if step == 0 and epoch % 10 == 0:\n", 
+ " picture_results(label, image, out)\n", + " for key, val in losses_epoch.items():\n", + " losses[key].append(val / len(train_loader))\n", + " for key, val in losses_epoch_val.items():\n", + " losses_val[key].append(val / len(val_loader))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "d79c8ced", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot losses\n", + "colors = [\"orangered\", \"royalblue\", \"hotpink\", \"lime\", \"goldenrod\"]\n", + "plt.figure(figsize=(5, 10))\n", + "ind = 0\n", + "for key, val in losses.items():\n", + " plt.subplot(len(losses.keys()), 1, ind + 1)\n", + " plt.plot(val, color=colors[ind], linestyle=\"-\")\n", + " plt.plot(losses_val[key], color=colors[ind], linestyle=\"--\")\n", + " plt.title(key)\n", + " plt.xlabel(\"Epochs\")\n", + " ind += 1\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c3cf096f", + "metadata": { + "pycharm": { + "name": "#%%" + } + }, + "source": [ + "**Conclusion**: from early on, the network shows the capability to discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid overfitting, seen in some loss plots of the previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the generator and the discriminator are trained only when their performance does not exceed a certain limit.\n" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.py b/tutorials/generative/2d_spade_gan/2d_spade_vae.py new file mode 100644 index 00000000..ec286f29 --- /dev/null +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.py @@ -0,0 +1,360 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:light +# 
text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.14.4 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# - + +# # SPADE VAE-GAN + +# This notebook creates a mock SPADE VAE-GAN based on the paper "Semantic Image Synthesis with Spatially-Adaptive Normalization" (2019) by Park T, Liu MY, Wang TC, Zhu JY. More information available at: https://github.com/NVlabs/SPADE + +import os +import tempfile +import matplotlib.pyplot as plt +import numpy as np +import torch +from pathlib import Path +import zipfile +import gdown +from monai.data import DataLoader +from tqdm import tqdm +from generative.losses import PatchAdversarialLoss, PerceptualLoss +from generative.networks.nets import MultiScalePatchDiscriminator +import numpy as np +import monai +from generative.networks.nets.spade_network import SPADE_Net + +directory = os.environ.get("MONAI_DATA_DIRECTORY") +root_dir = tempfile.mkdtemp() if directory is None else directory +root_dir = Path(root_dir) +print("Temporary directory used: %s " % root_dir) + +# INPUT PARAMETERS +input_shape = [128, 128] +batch_size = 6 +num_workers = 4 +num_epochs = 100 +lambda_perc = 1.0 +lambda_feat = 0.1 +lambda_kld = 0.00001 +loss_adv = 1.0 + +# ### Data + +# The data for this notebook comes from the public dataset OASIS (Open Access Series of Imaging Studies) 
[1]. The images have been registered to MNI space using ANTsPy, and then subsampled to 2mm isotropic resolution. Geodesic Information Flows (GIF) [2] has been used to segment 5 regions: cerebrospinal fluid (CSF), grey matter (GM), white matter (WM), deep grey matter (DGM) and brainstem. In addition, BaMos [3] has been used to provide white matter hyperintensities segmentations (WMH). The available dataset contains: +# - T1-weighted images +# - FLAIR weighted images +# - Segmentations with the following labels: 0 (background), 1 (CSF), 2 (GM), 3 (WM), 4 (DGM), 5 (brainstem) and 6 (WMH). +# +# _**Acknowledgments**: "Data were provided by OASIS-3: Longitudinal Multimodal Neuroimaging: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P30 AG066444, P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly owned subsidiary of Eli Lilly."_ +# +# +# Citations: +# +# [1] Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner, RL. Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults. Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498 +# +# [2] Cardoso MJ, Modat M, Wolz R, Melbourne A, Cash D, Rueckert D, Ourselin S. Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion. IEEE Trans Med Imaging. 2015 Sep;34(9):1976-88. doi: 10.1109/TMI.2015.2418298. Epub 2015 Apr 14. PMID: 25879909. +# +# [3] Fiford CM, Sudre CH, Pemberton H, Walsh P, Manning E, Malone IB, Nicholas J, Bouvy WH, Carmichael OT, Biessels GJ, Cardoso MJ, Barnes J; Alzheimer’s Disease Neuroimaging Initiative. Automated White Matter Hyperintensity Segmentation Using Bayesian Model Selection: Assessment and Correlations with Cognitive Change. Neuroinformatics. 2020 Jun;18(3):429-449. doi: 10.1007/s12021-019-09439-6. 
# Download the example archive into the working directory.
gdown.download(
    "https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5", str(root_dir / "data.zip")
)

# Extract inside a context manager so the archive handle is closed even if extraction fails.
with zipfile.ZipFile(os.path.join(root_dir, "data.zip"), "r") as zip_obj:
    zip_obj.extractall(root_dir)
images_T1 = root_dir / "OASIS_SMALL-SUBSET/T1"
images_FLAIR = root_dir / "OASIS_SMALL-SUBSET/FLAIR"
labels = root_dir / "OASIS_SMALL-SUBSET/Segmentations"

# We create the data dictionaries that we need
all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + [
    os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)
]
np.random.shuffle(all_images)
# The label file shares the image file name, with the leading "<prefix>_" token replaced by
# "Parcellation". os.path.basename is used instead of splitting on "/" so that the lookup also
# works when os.path.join produced a non-POSIX separator.
image_names = [os.path.basename(i) for i in all_images]
corresponding_labels = [
    os.path.join(labels, name.replace(name.split("_")[0], "Parcellation")) for name in image_names
]
input_dict = [
    {"image": image_path, "label": label_path} for image_path, label_path in zip(all_images, corresponding_labels)
]
# 90% / 10% train / validation split of the shuffled list.
split_index = int(len(input_dict) * 0.9)
input_dict_train = input_dict[:split_index]
input_dict_val = input_dict[split_index:]
monai.transforms.ToTensord(keys=["image", "label"]), +] + +aug_transforms = [ + monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=["image"]), + monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=["image"]), + monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), keys=["image"]), + monai.transforms.RandAffined( + rotate_range=[-0.05, 0.05], + shear_range=[0.001, 0.05], + scale_range=[0, 0.05], + padding_mode="zeros", + mode="nearest", + prob=0.33, + keys=["label", "image"], + ), +] + +train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms) +val_transforms = monai.transforms.Compose(base_transforms + last_transforms) + +train_dataset = monai.data.dataset.Dataset(input_dict_train, train_transforms) +val_dataset = monai.data.dataset.Dataset(input_dict_val, val_transforms) +train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers) +val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, batch_size=batch_size, num_workers=num_workers) + + +# - + +# Sanity check +batch = next(iter(train_loader)) +print(batch["image"].shape) +plt.subplot(1, 2, 1) +plt.imshow(batch["image"][0, 0, ...], cmap="gist_gray") +plt.axis("off") +plt.subplot(1, 2, 2) +plt.imshow(batch["label"][0, 0, ...], cmap="jet") +plt.axis("off") +plt.show() + +# ### Network creation and losses + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def one_hot(input_label, label_nc): + # One hot encoding function for the labels + shape_ = list(input_label.shape) + shape_[1] = label_nc + label_out = torch.zeros(shape_) + for channel in range(label_nc): + label_out[:, channel, ...] = input_label[:, 0, ...] 
== channel + return label_out + + +def picture_results(input_label, input_image, output_image): + f = plt.figure(figsize=(4, 1.5)) + plt.subplot(1, 3, 1) + plt.imshow(torch.argmax(input_label, 1)[0, ...].detach().cpu(), cmap="jet") + plt.axis("off") + plt.title("Label") + plt.subplot(1, 3, 2) + plt.imshow(input_image[0, 0, ...].detach().cpu(), cmap="gist_gray") + plt.axis("off") + plt.title("Input image") + plt.subplot(1, 3, 3) + plt.imshow(output_image[0, 0, ...].detach().cpu(), cmap="gist_gray") + plt.axis("off") + plt.title("Output image") + plt.show() + + +def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat, device): + criterion = torch.nn.L1Loss() + num_D = len(input_features_disc_fake) + GAN_Feat_loss = torch.zeros(1).to(device) + for i in range(num_D): # for each discriminator + num_intermediate_outputs = len(input_features_disc_fake[i]) + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = criterion(input_features_disc_fake[i][j], input_features_disc_real[i][j].detach()) + GAN_Feat_loss += unweighted_loss * lambda_feat / num_D + return GAN_Feat_loss + + +net = SPADE_Net( + spatial_dims=2, + in_channels=1, + out_channels=1, + label_nc=6, + input_shape=input_shape, + num_channels=[16, 32, 64, 128], + z_dim=16, + is_vae=True, +) + +# + +discriminator = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + num_channels=8, + in_channels=7, + out_channels=7, + minimum_size_im=128, + norm="INSTANCE", + kernel_size=3, +) + +adversarial_loss = PatchAdversarialLoss(reduction="sum", criterion="hinge") +# - + +perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg", is_fake_3d=False, pretrained=True) +perceptual_loss = perceptual_loss.to(device) + +optimizer_G = torch.optim.Adam(net.parameters(), lr=0.0002) +optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0004) + +# ### Training loop +# + +net = net.to(device) +discriminator = discriminator.to(device) 
# NOTE: anomaly detection is a debugging aid; it adds checks to every backward pass and slows
# training down, so consider disabling it once the loop is known to be stable.
torch.autograd.set_detect_anomaly(True)
loss_names = ("kld", "perceptual", "feature", "generator", "discriminator")
# Per-epoch mean of every loss term, for training and validation.
losses = {name: [] for name in loss_names}
losses_val = {name: [] for name in loss_names}
for epoch in range(num_epochs):
    print(f"Epoch {epoch}/{num_epochs}")
    train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120)
    losses_epoch = dict.fromkeys(loss_names, 0)
    for step, d in train_bar:
        image = d["image"].to(device)
        # NOTE(review): labels are one-hot encoded into 6 channels; confirm this matches the
        # number of classes actually present in the segmentations.
        with torch.no_grad():
            label = one_hot(d["label"], 6).to(device)
        optimizer_G.zero_grad()

        # Losses gen
        out, kld_loss = net(label, image)
        disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))
        loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)
        disc_reals, features_reals = discriminator(torch.cat([image, label], 1))
        loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)
        loss_perc = perceptual_loss(out, target=image)
        # NOTE(review): feature_loss() already scales its result by lambda_feat, so the feature
        # term below ends up weighted by lambda_feat twice — confirm this is intended.
        total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat
        total_loss.backward(retain_graph=True)
        optimizer_G.step()

        # Train disc
        # The generator output is only an input to the discriminator here, so it is produced
        # without tracking gradients; the discriminator update itself is unaffected.
        with torch.no_grad():
            out, _ = net(label, image)
        disc_fakes, _ = discriminator(torch.cat([out, label], 1))
        loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)
        loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True)
        # Clears the discriminator gradients left behind by the generator backward pass.
        optimizer_D.zero_grad()
        loss_d = loss_d_r + loss_g_f
        loss_d.backward()
        optimizer_D.step()

        # Store: every term is accumulated over the epoch so the division by the number of
        # batches at the end of the epoch yields a mean.
        disc_value = loss_d_r.item() + loss_g_f.item()
        losses_epoch["kld"] += kld_loss.item()
        losses_epoch["perceptual"] += loss_perc.item()
        losses_epoch["generator"] += loss_g.item()
        losses_epoch["feature"] += loss_feat.item()
        losses_epoch["discriminator"] += disc_value

        train_bar.set_postfix(
            {
                "kld": kld_loss.item(),
                "perceptual": loss_perc.item(),
                "generator": loss_g.item(),
                "feature": loss_feat.item(),
                "discriminator": disc_value,
            }
        )

    val_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=120)
    losses_epoch_val = dict.fromkeys(loss_names, 0)
    # Validation only reports metrics: no parameter is updated, so no autograd graph is built
    # and a single generator pass serves both the generator and the discriminator terms.
    with torch.no_grad():
        for step, d in val_bar:
            image = d["image"].to(device)
            label = one_hot(d["label"], 6).to(device)

            # Losses gen
            out, kld_loss = net(label, image)
            disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))
            loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)
            disc_reals, features_reals = discriminator(torch.cat([image, label], 1))
            loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)
            loss_perc = perceptual_loss(out, target=image)

            # Losses disc
            loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)
            loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True)

            # Store
            disc_value = loss_d_r.item() + loss_g_f.item()
            losses_epoch_val["kld"] += kld_loss.item()
            losses_epoch_val["perceptual"] += loss_perc.item()
            losses_epoch_val["generator"] += loss_g.item()
            losses_epoch_val["feature"] += loss_feat.item()
            losses_epoch_val["discriminator"] += disc_value

            val_bar.set_postfix(
                {
                    "kld": kld_loss.item(),
                    "perceptual": loss_perc.item(),
                    "generator": loss_g.item(),
                    "feature": loss_feat.item(),
                    "discriminator": disc_value,
                }
            )
            # Show one qualitative example from the first validation batch every 10 epochs.
            if step == 0 and epoch % 10 == 0:
                picture_results(label, image, out)
    for key, val in losses_epoch.items():
        losses[key].append(val / len(train_loader))
    for key, val in losses_epoch_val.items():
        losses_val[key].append(val / len(val_loader))


# Plot losses
# One panel per loss term: solid line for training, dashed line for validation.
colors = ["orangered", "royalblue", "hotpink", "lime", "goldenrod"]
plt.figure(figsize=(5, 10))
n_panels = len(losses)
for panel, (key, train_curve) in enumerate(losses.items()):
    plt.subplot(n_panels, 1, panel + 1)
    plt.plot(train_curve, color=colors[panel], linestyle="-")
    plt.plot(losses_val[key], color=colors[panel], linestyle="--")
    plt.title(key)
    plt.xlabel("Epochs")
plt.tight_layout()
plt.show()

# + [markdown] pycharm={"name": "#%%"}
# **Conclusion**: from early on, the network shows the ability to discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid the overfitting seen in some loss plots of the previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the generator and the discriminator are trained only when their performance does not exceed a certain limit.
#
# -