From 18e47e3a454409de9e95315e41bd75a5d26f8f82 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 6 Jun 2023 09:02:29 +0100 Subject: [PATCH 1/7] Code for SPADE VAE-GAN added, trimmed. Tests added (although at the moment, 3 /4 get executed). Addition of SPADE norm block and auxiliary KLD loss. --- generative/losses/kld_loss.py | 6 + generative/networks/blocks/spade_norm.py | 80 ++++++ generative/networks/nets/spade_network.py | 316 ++++++++++++++++++++++ tests/test_spade_vaegan.py | 113 ++++++++ 4 files changed, 515 insertions(+) create mode 100644 generative/losses/kld_loss.py create mode 100644 generative/networks/blocks/spade_norm.py create mode 100644 generative/networks/nets/spade_network.py create mode 100644 tests/test_spade_vaegan.py diff --git a/generative/losses/kld_loss.py b/generative/losses/kld_loss.py new file mode 100644 index 00000000..4b7e6f31 --- /dev/null +++ b/generative/losses/kld_loss.py @@ -0,0 +1,6 @@ +import torch.nn as nn +import torch + +class KLDLoss(nn.Module): + def forward(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) diff --git a/generative/networks/blocks/spade_norm.py b/generative/networks/blocks/spade_norm.py new file mode 100644 index 00000000..68991f0c --- /dev/null +++ b/generative/networks/blocks/spade_norm.py @@ -0,0 +1,80 @@ +from __future__ import annotations +import torch +import torch.nn as nn +from monai.networks.blocks import Convolution, ADN +import torch.nn.functional as F + +class SPADE(nn.Module): + """ + SPADE normalisation block based on the 2019 paper by Park et al. 
(doi: https://doi.org/10.48550/arXiv.1903.07291) + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + normalisation: type of base normalisation used before applying the SPADE normalisation + """ + def __init__(self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple= "INSTANCE", + norm_params: dict = {} + )-> None: + + super().__init__() + + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN(act=None, dropout=0.0, norm = norm, + norm_dim=spatial_dims, + ordering="N", + in_channels=norm_nc) + self.mlp_shared = Convolution(spatial_dims=spatial_dims, + in_channels = label_nc, + out_channels = hidden_channels, + kernel_size= kernel_size, + norm = None, + padding=kernel_size//2, + act="LEAKYRELU") + self.mlp_gamma = Convolution(spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding = kernel_size//2, + act = None + ) + self.mlp_beta = Convolution(spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding = kernel_size//2, + act = None + ) + + + def forward(self, + x: torch.Tensor, + segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor + segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + Returns: + + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py new file mode 100644 index 00000000..d28a069d --- /dev/null +++ b/generative/networks/nets/spade_network.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +from typing import Union, Sequence +import numpy as np +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from networks.blocks.spade_norm import SPADE +from monai.utils.enums import StrEnum +import torch.nn.functional as F +from losses.kld_loss import KLDLoss + +class UpsamplingModes(StrEnum): + bicubic = "bicubic" + nearest = "nearest" + bilinear = "bilinear" + +class SPADE_ResNetBlock(nn.Module): + + def __init__(self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + kernel_size: int = 3,): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.int_channels = min(self.in_channels, self.out_channels) + self.learned_shortcut = self.in_channels != self.out_channels + self.conv_0 = Convolution(spatial_dims = spatial_dims, + in_channels = self.in_channels, + out_channels = self.int_channels, + act = None, + norm = None, + ) + self.conv_1 = Convolution(spatial_dims = spatial_dims, + in_channels = self.int_channels, + out_channels = self.out_channels, + act = None, + norm = None, + ) + self.activation = nn.LeakyReLU(0.2, False) + self.norm_0 = SPADE(label_nc=label_nc, norm_nc=self.in_channels, kernel_size=kernel_size, + spatial_dims=spatial_dims, hidden_channels=spade_intermediate_channels, + norm=norm) + self.norm_1 = SPADE(label_nc=label_nc, norm_nc=self.int_channels, kernel_size=kernel_size, + spatial_dims=spatial_dims, hidden_channels=spade_intermediate_channels, + norm=norm) + + if self.learned_shortcut: + self.conv_s = Convolution(spatial_dims = spatial_dims, + in_channels = self.in_channels, + out_channels = self.out_channels, + act = None, + norm = None, + kernel_size=1, + ) + self.norm_s = SPADE(label_nc=label_nc, norm_nc=self.in_channels, kernel_size=kernel_size, + spatial_dims=spatial_dims, 
hidden_channels=spade_intermediate_channels, + norm=norm) + + def forward(self, x, seg): + + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.activation(self.norm_0(x, seg))) + dx = self.conv_1(self.activation(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + +class SPADE_Encoder(nn.Module): + + def __init__(self, + spatial_dims: int, + in_channels: int, + z_dim: int, + num_channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple] = (Act.LEAKYRELU, {"negative_slope": 0.2})): + + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.num_channels = num_channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" %(input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): + raise ValueError("Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. 
" %(s_ind, s_, len(num_channels))) + self.input_shape = input_shape + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in self.input_shape] + blocks = [] + ch_init = self.in_channels + for ch_ind, ch_value in enumerate(num_channels): + blocks.append(Convolution(spatial_dims = spatial_dims, + in_channels = ch_init, + out_channels= ch_value, + strides=2, + kernel_size=kernel_size, + norm = norm, + act = act)) + ch_init = ch_value + + self.blocks = nn.ModuleList(blocks) + self.fc_mu = nn.Linear(in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], + out_features=self.z_dim) + self.fc_var = nn.Linear(in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], + out_features=self.z_dim) + + def forward(self, x,): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return mu, logvar + + def encode(self, x): + + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return self.reparameterize(mu, logvar) + + def reparameterize(self, mu, logvar): + + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + +class SPADE_Decoder(nn.Module): + + def __init__(self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_gan: bool = False, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + + super().__init__() + self.is_gan = is_gan + self.out_channels = out_channels + self.label_nc = label_nc + self.num_channels = num_channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of 
parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): + raise ValueError("Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % ( + s_ind, s_, len(num_channels))) + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] + + if self.is_gan: + self.fc = nn.Linear(label_nc, np.prod(self.latent_spatial_shape) * num_channels[0]) + else: + self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * num_channels[0]) + + blocks = [] + num_channels.append(self.out_channels) + self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) + for ch_ind, ch_value in enumerate(num_channels[:-1]): + blocks.append(SPADE_ResNetBlock(spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=num_channels[ch_ind+1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size),) + + self.blocks = torch.nn.ModuleList(blocks) + self.last_conv = Convolution(spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + padding=(kernel_size-1)//2, + kernel_size=kernel_size, + norm = None, + act=last_act + ) + + + def forward(self, seg, z: torch.Tensor = None): + + + if self.is_gan: + x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) + x = self.fc(x) + else: + if z is None: + z = torch.randn(seg.size(0), self.opt.z_dim, + dtype=torch.float32, device=seg.get_device()) + x = self.fc(z) + x = x.view(*[-1, self.num_channels[0]]+self.latent_spatial_shape) + + for res_block in self.blocks: + x = res_block(x, seg) + x = self.upsampling(x) + + x = self.last_conv(x) + return x + +class SPADE_Net(nn.Module): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: 
Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm:Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value + + ): + + super().__init__() + self.is_vae = is_vae + if self.is_vae and z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + self.label_nc = label_nc + self.input_shape = input_shape + self.kld_loss = KLDLoss() + + if self.is_vae: + self.encoder = SPADE_Encoder( + spatial_dims = spatial_dims, + in_channels = in_channels, + z_dim = z_dim, + num_channels = num_channels, + input_shape = input_shape, + kernel_size = kernel_size, + norm = norm, + act = act) + + decoder_channels = num_channels + decoder_channels.reverse() + + self.decoder = SPADE_Decoder( + spatial_dims=spatial_dims, + out_channels=out_channels, + label_nc=label_nc, + input_shape=input_shape, + num_channels= decoder_channels, + z_dim = z_dim, + is_gan = not is_vae, + spade_intermediate_channels = spade_intermediate_channels, + norm = norm, + act = act, + last_act = last_act, + kernel_size=kernel_size, + upsampling_mode=upsampling_mode + ) + + def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): + + z = None + if self.is_vae: + z_mu, z_logvar = self.encoder(x) + z = self.encoder.reparameterize(z_mu, z_logvar) + kld_loss = self.kld_loss(z_mu, z_logvar) + return self.decoder(seg, z), kld_loss + else: + return self.decoder(seg, z), + + def encode(self, x: torch.Tensor): + + return self.encoder.encode(x) + + def decode(self, seg: torch.Tensor, z: Union[torch.Tensor, None] = None): + + return 
self.decoder(seg, z) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py new file mode 100644 index 00000000..8ea62b9f --- /dev/null +++ b/tests/test_spade_vaegan.py @@ -0,0 +1,113 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import unittest +import torch +from monai.networks import eval_mode +from parameterized import parameterized +from networks.nets.spade_network import SPADE_Net +import numpy as np + +CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] + +def create_Semantic_Data(shape:list, semantic_regions:int): + ''' + To create semantic and image mock inputs for the network. 
+ Args: + shape: input shape + semantic_regions: number of semantic regions + Returns: + ''' + out_label = torch.zeros(shape) + out_image = torch.zeros(shape) + torch.randn(shape)*0.01 + for i in range(1, semantic_regions): + shape_square = [i//np.random.choice(list(range(2, i//2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind]-shape_square[ind]))) + for ind, i in enumerate(shape)] + if len(shape) == 2: + out_label[start_point[0]:(start_point[0]+shape_square[0]), + start_point[1]:(start_point[1]+shape_square[1])] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[start_point[0]:(start_point[0] + shape_square[0]), + start_point[1]:(start_point[1] + shape_square[1])] = base_intensity + \ + torch.randn(shape_square)*0.1 + elif len(shape) == 3: + out_label[start_point[0]:(start_point[0]+shape_square[0]), + start_point[1]:(start_point[1]+shape_square[1]), + start_point[2]:(start_point[2] + shape_square[2])] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[start_point[0]:(start_point[0]+shape_square[0]), + start_point[1]:(start_point[1]+shape_square[1]), + start_point[2]:(start_point[2] + shape_square[2])] = base_intensity + \ + torch.randn(shape_square)*0.1 + else: + ValueError("Supports only 2D and 3D tensors") + + # One hot encode label + out_label_ = torch.zeros([semantic_regions,] + list(out_label.shape)) + for ch in range(semantic_regions): + out_label_[ch, ...] = out_label == ch + + return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) + +class TestDiffusionModelUNet2D(unittest.TestCase): + + @parameterized.expand(CASE_2D) + def test_forward_2d(self, input_param): + ''' + Check that forward method is called correctly and output shape matches. 
+ ''' + net = SPADE_Net(*input_param) + in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual(False, True in torch.isnan(out) or True in torch.isinf(out) + or True in torch.isinf(kld) or True in torch.isinf(kld)) + self.assertEqual(list(out.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_2D_BIS) + def test_encoder_decoder(self, input_param): + ''' + Check that forward method is called correctly and output shape matches. + ''' + net = SPADE_Net(*input_param) + in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + with eval_mode(net): + out_z = net.encode(in_image) + self.assertEqual(list(out_z.shape), [1, 16]) + out_i = net.decode(in_label, out_z) + self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_3D) + def test_forward_2d(self, input_param): + ''' + Check that forward method is called correctly and output shape matches. + ''' + net = SPADE_Net(*input_param) + in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual(False, True in torch.isnan(out) or True in torch.isinf(out) + or True in torch.isinf(kld) or True in torch.isinf(kld)) + self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) + + def test_shape_wrong(self): + ''' + We input an input shape that isn't divisible by 2**(n downstream steps) + ''' + with self.assertRaises(ValueError): + net = SPADE_Net(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + +if __name__ == "__main__": + unittest.main() From 390eb58e921ad0ec9b87eda63c4c95d5d1c0a324 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 20 Jun 2023 22:11:51 +0100 Subject: [PATCH 2/7] Added SPADE network code, tests and jupyter notebook for 2D --- generative/networks/nets/spade_network.py | 4 +- tests/test_spade_vaegan.py | 2 +- .../2d_autoencoderkl_tutorial.ipynb | 41 ++++++++++--------- 
.../2d_autoencoderkl_tutorial.py | 1 - 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py index d28a069d..ebf601e6 100644 --- a/generative/networks/nets/spade_network.py +++ b/generative/networks/nets/spade_network.py @@ -15,10 +15,10 @@ import numpy as np from monai.networks.blocks import Convolution from monai.networks.layers import Act -from networks.blocks.spade_norm import SPADE +from generative.networks.blocks.spade_norm import SPADE from monai.utils.enums import StrEnum import torch.nn.functional as F -from losses.kld_loss import KLDLoss +from generative.losses.kld_loss import KLDLoss class UpsamplingModes(StrEnum): bicubic = "bicubic" diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 8ea62b9f..3354d4d6 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -90,7 +90,7 @@ def test_encoder_decoder(self, input_param): self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) @parameterized.expand(CASE_3D) - def test_forward_2d(self, input_param): + def test_forward_3d(self, input_param): ''' Check that forward method is called correctly and output shape matches. 
''' diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb index 11d059c3..2b398ee5 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "350736c2", "metadata": {}, "outputs": [ @@ -80,29 +80,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "MONAI version: 1.1.dev2239\n", - "Numpy version: 1.23.3\n", - "Pytorch version: 1.8.0+cu111\n", + "MONAI version: 1.2.dev2304\n", + "Numpy version: 1.23.5\n", + "Pytorch version: 1.13.1+cu117\n", "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", - "MONAI rev id: 13b24fa92b9d98bd0dc6d5cdcb52504fd09e297b\n", - "MONAI __file__: /media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/monai/__init__.py\n", + "MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0\n", + "MONAI __file__: /home/vf19/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/monai/__init__.py\n", "\n", "Optional dependencies:\n", "Pytorch Ignite version: 0.4.10\n", - "Nibabel version: 4.0.2\n", - "scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n", - "Pillow version: 9.2.0\n", - "Tensorboard version: 2.11.0\n", - "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n", - "TorchVision version: 0.9.0+cu111\n", + "ITK version: 5.3.0\n", + "Nibabel version: 5.0.0\n", + "scikit-image version: 0.19.3\n", + "Pillow version: 9.4.0\n", + "Tensorboard version: 2.12.0\n", + "gdown version: 4.6.3\n", + "TorchVision version: 0.14.1+cu117\n", "tqdm version: 4.64.1\n", - "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n", - "psutil version: 5.9.3\n", - "pandas version: NOT INSTALLED or UNKNOWN VERSION.\n", + "lmdb version: 1.4.0\n", + "psutil version: 5.9.4\n", + "pandas version: 1.5.3\n", 
"einops version: 0.6.0\n", - "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n", - "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n", - "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.1.1\n", + "pynrrd version: 1.0.0\n", "\n", "For details about installing the optional dependencies, please visit:\n", " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", @@ -115,7 +116,7 @@ "import shutil\n", "import tempfile\n", "import time\n", - "\n", + "from pathlib import Path\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", @@ -1229,7 +1230,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py index 53ccf898..39e44730 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py @@ -42,7 +42,6 @@ import shutil import tempfile import time - import matplotlib.pyplot as plt import numpy as np import torch From 97737b19a4b4a7b4778d482c3660f83741e515a9 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 20 Jun 2023 22:23:10 +0100 Subject: [PATCH 3/7] Added SPADE network code, tests and jupyter notebook for 2D --- .../2d_spade_gan/2d_spade_vae.ipynb | 937 ++++++++++++++++++ .../generative/2d_spade_gan/2d_spade_vae.py | 330 ++++++ 2 files changed, 1267 insertions(+) create mode 100644 tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb create mode 100644 tutorials/generative/2d_spade_gan/2d_spade_vae.py diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb new file mode 100644 index 00000000..c7c29aaf --- /dev/null +++ 
b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -0,0 +1,937 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "id": "1f7ba8ce", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "102909fb", + "metadata": {}, + "source": [ + "# SPADE VAE-GAN" + ] + }, + { + "cell_type": "markdown", + "id": "7d4cbc3c", + "metadata": {}, + "source": [ + "This notebook creates a mock SPADE VAE-GAN based on the paper \"Semantic Image Synthesis with Spatially-Adaptive Normalization\" (2019) by Park T, Liu MY, Wang TC, Zhu JY. 
More information available at: https://github.com/NVlabs/SPADE" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e059c423", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from pathlib import Path\n", + "import zipfile\n", + "import gdown\n", + "from monai.data import DataLoader\n", + "from tqdm import tqdm\n", + "from generative.losses import PatchAdversarialLoss, PerceptualLoss\n", + "from generative.networks.nets import MultiScalePatchDiscriminator\n", + "import numpy as np\n", + "import monai\n", + "from generative.networks.nets.spade_network import SPADE_Net" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e76296e7", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temporary directory used: /tmp/tmpy_otj3u5 \n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "root_dir = Path(root_dir)\n", + "print(\"Temporary directory used: %s \" %root_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2483148a", + "metadata": {}, + "outputs": [], + "source": [ + "# INPUT PARAMETERS\n", + "input_shape = [128, 128]\n", + "batch_size = 6\n", + "num_workers = 4\n", + "num_epochs = 100\n", + "lambda_perc = 1.0\n", + "lambda_feat = 0.1\n", + "lambda_kld = 0.00001\n", + "loss_adv = 1.0" + ] + }, + { + "cell_type": "markdown", + "id": "3f2448ba", + "metadata": {}, + "source": [ + "### Data" + ] + }, + { + "cell_type": "markdown", + "id": "2c91eec2", + "metadata": {}, + "source": [ + "The data for this notebook comes from the public dataset OASIS (Open Access Series of Imaging Studies) [1]. 
The images have been registered to MNI space using ANTsPy, and then subsampled to 2mm isotropic resolution. Geodesic Information Flows (GIF) [2] has been used to segment 5 regions: cerebrospinal fluid (CSF), grey matter (GM), white matter (WM), deep grey matter (DGM) and brainstem. In addition, BaMos [3] has been used to provide white matter hyperintensities segmentations (WMH). The available dataset contains:\n", + "- T1-weighted images\n", + "- FLAIR weighted images\n", + "- Segmentations with the following labels: 0 (background), 1 (CSF), 2 (GM), 3 (WM), 4 (DGM), 5 (brainstem) and 6 (WMH).\n", + "\n", + "_**Acknowledgments**: \"Data were provided by OASIS-3: Longitudinal Multimodal Neuroimaging: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P30 AG066444, P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly owned subsidiary of Eli Lilly.”_\n", + "\n", + "\n", + "Citations:\n", + "\n", + "[1] Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner. Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults, RL. Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498\n", + "\n", + "[2] Cardoso MJ, Modat M, Wolz R, Melbourne A, Cash D, Rueckert D, Ourselin S. Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion. IEEE Trans Med Imaging. 2015 Sep;34(9):1976-88. doi: 10.1109/TMI.2015.2418298. Epub 2015 Apr 14. PMID: 25879909.\n", + "\n", + "[3] Fiford CM, Sudre CH, Pemberton H, Walsh P, Manning E, Malone IB, Nicholas J, Bouvy WH, Carmichael OT, Biessels GJ, Cardoso MJ, Barnes J; Alzheimer’s Disease Neuroimaging Initiative. Automated White Matter Hyperintensity Segmentation Using Bayesian Model Selection: Assessment and Correlations with Cognitive Change. Neuroinformatics. 
2020 Jun;18(3):429-449. doi: 10.1007/s12021-019-09439-6. PMID: 32062817; PMCID: PMC7338814.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dc560f7e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\n", + "To: /tmp/tmpy_otj3u5/data.zip\n", + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384M/384M [00:06<00:00, 62.8MB/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "'/tmp/tmpy_otj3u5/data.zip'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gdown.download(\"https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\",\n", + " str(root_dir / 'data.zip'))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "cd7dd6ec", + "metadata": {}, + "outputs": [], + "source": [ + "zip_obj = zipfile.ZipFile(os.path.join(root_dir, 'data.zip'), 'r')\n", + "zip_obj.extractall(root_dir)\n", + "images_T1 = root_dir / \"OASIS_SMALL-SUBSET/T1\"\n", + "images_FLAIR = root_dir / \"OASIS_SMALL-SUBSET/FLAIR\"\n", + "labels = root_dir / \"OASIS_SMALL-SUBSET/Segmentations\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d48987b9", + "metadata": {}, + "outputs": [], + "source": [ + "# We create the data dictionaries that we need\n", + "all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + \\\n", + " [os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)]\n", + "np.random.shuffle(all_images)\n", + "corresponding_labels = [os.path.join(labels, i.split(\"/\")[-1].replace(i.split(\"/\")[-1].split(\"_\")[0], \"Parcellation\"))\n", + " for i in all_images]\n", + "input_dict = [{'image': i, 'label': corresponding_labels[ind]} for ind, i in 
enumerate(all_images)]\n", + "input_dict_train = input_dict[:int(len(input_dict)*0.9)]\n", + "input_dict_val = input_dict[int(len(input_dict)*0.9):]" + ] + }, + { + "cell_type": "markdown", + "id": "9916ca5a", + "metadata": {}, + "source": [ + "### Dataloaders" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8ab79cdc", + "metadata": { + "lines_to_end_of_cell_marker": 2 + }, + "outputs": [], + "source": [ + "preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain\n", + "crop_shape = input_shape + [1]\n", + "base_transforms = [\n", + " monai.transforms.LoadImaged(keys = ['label', 'image']),\n", + " monai.transforms.EnsureChannelFirstd(keys=['image', 'label']),\n", + " monai.transforms.CenterSpatialCropd(keys=['label', 'image'],\n", + " roi_size=preliminar_shape),\n", + " monai.transforms.RandSpatialCropd(keys = ['label', 'image'],\n", + " roi_size=crop_shape, max_roi_size=crop_shape),\n", + " monai.transforms.SqueezeDimd(keys=['label', 'image'], dim = -1),\n", + " monai.transforms.Resized(keys = ['image', 'label'], spatial_size=input_shape),\n", + "]\n", + "last_transforms = [\n", + " monai.transforms.CopyItemsd(keys=['label'], names=['label_channel']),\n", + " monai.transforms.Lambdad(keys=['label_channel'],\n", + " func=lambda l: l != 0),\n", + " monai.transforms.MaskIntensityd(keys=['image'], mask_key='label_channel'),\n", + " monai.transforms.NormalizeIntensityd(keys=['image']),\n", + " monai.transforms.ToTensord(keys=['image', 'label'])\n", + " ]\n", + "\n", + "aug_transforms = [\n", + " monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=['image']),\n", + " monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=['image']),\n", + " monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015),\n", + " keys=['image']),\n", + " monai.transforms.RandAffined(rotate_range=[-0.05, 0.05], shear_range=[0.001, 0.05],\n", + " scale_range=[0, 0.05], 
padding_mode='zeros',\n", + " mode='nearest', prob=0.33, keys=['label', 'image'])\n", + " ]\n", + "\n", + "train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms)\n", + "val_transforms = monai.transforms.Compose(base_transforms + last_transforms)\n", + "\n", + "train_dataset = monai.data.dataset.Dataset(input_dict_train, train_transforms)\n", + "val_dataset = monai.data.dataset.Dataset(input_dict_val, val_transforms)\n", + "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)\n", + "val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, batch_size=batch_size, num_workers=num_workers)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "98d14e75", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([6, 1, 128, 128])\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Sanity check\n", + "batch = next(iter(train_loader))\n", + "print(batch['image'].shape)\n", + "plt.subplot(1,2,1)\n", + "plt.imshow(batch['image'][0,0,...], cmap = 'gist_gray'); plt.axis('off')\n", + "plt.subplot(1,2,2)\n", + "plt.imshow(batch['label'][0,0,...], cmap = \"jet\"); plt.axis('off')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "63de4490", + "metadata": {}, + "source": [ + "### Network creation and losses" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "fa17d864", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8f126b17", + "metadata": {}, + "outputs": [], + "source": [ + "def one_hot(input_label, label_nc):\n", + " # One hot encoding function for the labels\n", + " shape_ = list(input_label.shape)\n", + " shape_[1] = label_nc\n", + " label_out = torch.zeros(shape_)\n", + " for channel in range(label_nc):\n", + " label_out[:, channel, ...] = input_label[:, 0, ...] 
== channel\n", + " return label_out" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af2779b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "255c90c7", + "metadata": {}, + "outputs": [], + "source": [ + "def picture_results(input_label, input_image, output_image):\n", + " f = plt.figure(figsize = (4, 1.5))\n", + " plt.subplot(1,3,1)\n", + " plt.imshow(torch.argmax(input_label, 1)[0,...].detach().cpu(), cmap = 'jet')\n", + " plt.axis('off')\n", + " plt.title(\"Label\")\n", + " plt.subplot(1,3,2)\n", + " plt.imshow(input_image[0,0,...].detach().cpu(), cmap = 'gist_gray')\n", + " plt.axis('off')\n", + " plt.title(\"Input image\")\n", + " plt.subplot(1,3,3)\n", + " plt.imshow(output_image[0,0,...].detach().cpu(), cmap = 'gist_gray')\n", + " plt.axis('off')\n", + " plt.title(\"Output image\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaa62145", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c18dbad8", + "metadata": {}, + "outputs": [], + "source": [ + "def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat, device):\n", + " criterion = torch.nn.L1Loss()\n", + " num_D = len(input_features_disc_fake)\n", + " GAN_Feat_loss = torch.zeros(1).to(device)\n", + " for i in range(num_D): # for each discriminator\n", + " num_intermediate_outputs = len(input_features_disc_fake[i])\n", + " for j in range(num_intermediate_outputs): # for each layer output\n", + " unweighted_loss = criterion(input_features_disc_fake[i][j],\n", + " input_features_disc_real[i][j].detach())\n", + " GAN_Feat_loss += unweighted_loss * lambda_feat / num_D\n", + " return GAN_Feat_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "89989c34", + "metadata": {}, + "outputs": [], + "source": [ + "net = SPADE_Net(spatial_dims = 2,\n", + " in_channels = 
1,\n", + " out_channels = 1,\n", + " label_nc = 6,\n", + " input_shape = input_shape,\n", + " num_channels = [16, 32, 64, 128],\n", + " z_dim = 16,\n", + " is_vae = True)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5b8b676f", + "metadata": {}, + "outputs": [], + "source": [ + "discriminator = MultiScalePatchDiscriminator(num_d = 2,\n", + " num_layers_d = 3,\n", + " spatial_dims = 2,\n", + " num_channels = 8,\n", + " in_channels = 7,\n", + " out_channels = 7,\n", + " minimum_size_im = 128,\n", + " norm = \"INSTANCE\",\n", + " kernel_size = 3\n", + " )\n", + "\n", + "adversarial_loss = PatchAdversarialLoss(reduction = \"sum\", criterion = \"hinge\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "36ea4308", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. 
You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n" + ] + } + ], + "source": [ + "perceptual_loss = PerceptualLoss(spatial_dims = 2,\n", + " network_type = \"vgg\",\n", + " is_fake_3d = False,\n", + " pretrained = True)\n", + "perceptual_loss=perceptual_loss.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3b57abd3", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer_G = torch.optim.Adam(net.parameters(), lr = 0.0002)\n", + "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = 0.0004)" + ] + }, + { + "cell_type": "markdown", + "id": "b8fde71b", + "metadata": {}, + "source": [ + "### Training loop\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "918eac0a", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████| 15/15 [00:08<00:00, 1.74it/s, kld=378, perceptual=0.395, generator=2.04, feature=0.206, discriminator=3.69]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.18it/s, kld=252, perceptual=0.362, generator=2.1, feature=0.204, discriminator=3.61]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=279, perceptual=0.395, generator=2.15, feature=0.202, discriminator=3.7]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.25it/s, kld=184, perceptual=0.421, generator=2.2, feature=0.196, discriminator=3.36]" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Epoch 2/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.14it/s, kld=207, perceptual=0.421, generator=2.32, feature=0.214, discriminator=2.83]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.47it/s, kld=125, perceptual=0.414, generator=2.33, feature=0.222, discriminator=2.82]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=131, perceptual=0.384, generator=2.31, feature=0.219, discriminator=2.68]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.38it/s, kld=87.2, perceptual=0.403, generator=2.35, feature=0.221, discriminator=2.59]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:07<00:00, 2.12it/s, kld=170, perceptual=0.353, generator=2.46, feature=0.17, discriminator=2.51]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.34it/s, kld=113, perceptual=0.339, generator=2.48, feature=0.174, discriminator=2.51]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=181, perceptual=0.321, generator=2.52, feature=0.203, discriminator=2.23]\n", + " 0%| | 0/2 [00:00" + ] + 
}, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 2.59it/s, kld=151, perceptual=0.348, generator=2.49, feature=0.198, discriminator=2.25]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=179, perceptual=0.365, generator=1.58, feature=0.125, discriminator=3.83]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.08it/s, kld=108, perceptual=0.384, generator=1.93, feature=0.183, discriminator=3.43]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.15it/s, kld=407, perceptual=0.38, generator=1.84, feature=0.127, discriminator=3.41]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.35it/s, kld=302, perceptual=0.393, generator=2.02, feature=0.133, discriminator=2.79]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=356, perceptual=0.365, generator=2.34, feature=0.166, discriminator=2.39]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.20it/s, kld=237, perceptual=0.37, generator=2.35, feature=0.176, discriminator=2.37]" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Epoch 9/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " 60%|████▏ | 9/15 [00:04<00:03, 1.87it/s, kld=194, perceptual=0.35, generator=2.63, feature=0.156, discriminator=2.57]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[22], line 36\u001B[0m\n\u001B[1;32m 34\u001B[0m optimizer_D\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[1;32m 35\u001B[0m loss_d \u001B[38;5;241m=\u001B[39m loss_d_r \u001B[38;5;241m+\u001B[39m loss_g_f\n\u001B[0;32m---> 36\u001B[0m \u001B[43mloss_d\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 37\u001B[0m optimizer_D\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m 39\u001B[0m \u001B[38;5;66;03m# Store\u001B[39;00m\n", + "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:479\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 432\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124;03m\"\"\"Computes the gradient of current tensor w.r.t. graph leaves.\u001B[39;00m\n\u001B[1;32m 433\u001B[0m \n\u001B[1;32m 434\u001B[0m \u001B[38;5;124;03mThe graph is differentiated using the chain rule. 
If the tensor is\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 476\u001B[0m \u001B[38;5;124;03m used to compute the attr::tensors.\u001B[39;00m\n\u001B[1;32m 477\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m--> 479\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mhandle_torch_function\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 480\u001B[0m \u001B[43m \u001B[49m\u001B[43mTensor\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 481\u001B[0m \u001B[43m \u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 482\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 483\u001B[0m \u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 484\u001B[0m \u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 485\u001B[0m \u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 486\u001B[0m \u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 487\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 488\u001B[0m torch\u001B[38;5;241m.\u001B[39mautograd\u001B[38;5;241m.\u001B[39mbackward(\n\u001B[1;32m 489\u001B[0m \u001B[38;5;28mself\u001B[39m, gradient, retain_graph, create_graph, inputs\u001B[38;5;241m=\u001B[39minputs\n\u001B[1;32m 490\u001B[0m )\n", + "File 
\u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/overrides.py:1534\u001B[0m, in \u001B[0;36mhandle_torch_function\u001B[0;34m(public_api, relevant_args, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1528\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDefining your `__torch_function__ as a plain method is deprecated and \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 1529\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mwill be an error in future, please define it as a classmethod.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m 1530\u001B[0m \u001B[38;5;167;01mDeprecationWarning\u001B[39;00m)\n\u001B[1;32m 1532\u001B[0m \u001B[38;5;66;03m# Use `public_api` instead of `implementation` so __torch_function__\u001B[39;00m\n\u001B[1;32m 1533\u001B[0m \u001B[38;5;66;03m# implementations can do equality/identity comparisons.\u001B[39;00m\n\u001B[0;32m-> 1534\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[43mtorch_func_method\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpublic_api\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtypes\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1536\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m result \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mNotImplemented\u001B[39m:\n\u001B[1;32m 1537\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m result\n", + "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/monai/data/meta_tensor.py:276\u001B[0m, in \u001B[0;36mMetaTensor.__torch_function__\u001B[0;34m(cls, func, types, args, kwargs)\u001B[0m\n\u001B[1;32m 274\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m kwargs \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 275\u001B[0m kwargs \u001B[38;5;241m=\u001B[39m 
{}\n\u001B[0;32m--> 276\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__torch_function__\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfunc\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtypes\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 277\u001B[0m \u001B[38;5;66;03m# if `out` has been used as argument, metadata is not copied, nothing to do.\u001B[39;00m\n\u001B[1;32m 278\u001B[0m \u001B[38;5;66;03m# if \"out\" in kwargs:\u001B[39;00m\n\u001B[1;32m 279\u001B[0m \u001B[38;5;66;03m# return ret\u001B[39;00m\n\u001B[1;32m 280\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m _not_requiring_metadata(ret):\n", + "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:1279\u001B[0m, in \u001B[0;36mTensor.__torch_function__\u001B[0;34m(cls, func, types, args, kwargs)\u001B[0m\n\u001B[1;32m 1276\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mNotImplemented\u001B[39m\n\u001B[1;32m 1278\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m _C\u001B[38;5;241m.\u001B[39mDisableTorchFunction():\n\u001B[0;32m-> 1279\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1280\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m func \u001B[38;5;129;01min\u001B[39;00m get_default_nowrap_functions():\n\u001B[1;32m 1281\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m ret\n", + "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:488\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, 
gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 479\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 480\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 481\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 486\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m 487\u001B[0m )\n\u001B[0;32m--> 488\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 489\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m 490\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/autograd/__init__.py:197\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 192\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 194\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 195\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 196\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 197\u001B[0m 
\u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 198\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 199\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n", + "\u001B[0;31mKeyboardInterrupt\u001B[0m: " + ] + } + ], + "source": [ + "net = net.to(device)\n", + "discriminator = discriminator.to(device)\n", + "torch.autograd.set_detect_anomaly(True)\n", + "for epoch in range(num_epochs):\n", + " print(\"Epoch %d/%d\" %(epoch, num_epochs))\n", + " train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120)\n", + " losses_epoch = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0}\n", + " for step, d in train_bar:\n", + " image = d['image'].to(device)\n", + " with torch.no_grad():\n", + " label = one_hot(d['label'], 6).to(device)\n", + " optimizer_G.zero_grad()\n", + "\n", + " # Losses gen\n", + " out, kld_loss = net(label, image)\n", + " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False)\n", + " disc_reals, features_reals = discriminator(torch.cat([image, label], 
1))\n", + " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", + " loss_perc = perceptual_loss(out, target = image)\n", + " total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " total_loss.backward(retain_graph = True)\n", + " optimizer_G.step()\n", + "\n", + " # Store\n", + " losses_epoch['kld'] += kld_loss.item()\n", + " losses_epoch['perceptual'] += loss_perc.item()\n", + " losses_epoch['generator'] += loss_g.item()\n", + " #Train disc\n", + " out, _ = net(label, image)\n", + " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True)\n", + " optimizer_D.zero_grad()\n", + " loss_d = loss_d_r + loss_g_f\n", + " loss_d.backward()\n", + " optimizer_D.step()\n", + "\n", + " # Store\n", + " losses_epoch['feature'] = loss_feat.item()\n", + " losses_epoch['discriminator'] = loss_d_r.item() + loss_g_f.item()\n", + "\n", + " train_bar.set_postfix(\n", + " {\"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", + " })\n", + "\n", + " val_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=120)\n", + " losses_epoch_val = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0}\n", + " for step, d in val_bar:\n", + " image = d['image'].to(device)\n", + " with torch.no_grad():\n", + " label = one_hot(d['label'], 6).to(device)\n", + " # Losses gen\n", + " out, kld_loss = net(label, image)\n", + " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False)\n", + " disc_reals, features_reals = 
discriminator(torch.cat([image, label], 1))\n", + " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", + " loss_perc = perceptual_loss(out, target = image)\n", + " total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " # Store\n", + " losses_epoch_val['kld'] += kld_loss.item()\n", + " losses_epoch_val['perceptual'] += loss_perc.item()\n", + " losses_epoch_val['generator'] += loss_g.item()\n", + " #Train disc\n", + " out, _ = net(label, image)\n", + " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True)\n", + " loss_d = loss_adv * (loss_d_r + loss_g_f)\n", + "\n", + " # Store\n", + " losses_epoch_val['feature'] = loss_feat.item()\n", + " losses_epoch_val['discriminator'] = loss_d_r.item() + loss_g_f.item()\n", + "\n", + " val_bar.set_postfix(\n", + " {\"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", + " })\n", + " if step == 0 and epoch%10==0:\n", + " picture_results(label, image, out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f579376b", + "metadata": { + "pycharm": { + "name": "#%%" + } + }, + "outputs": [], + "source": [ + "\n" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} \ No newline at end of file diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.py b/tutorials/generative/2d_spade_gan/2d_spade_vae.py new file mode 100644 index 00000000..1f2a93ac --- /dev/null +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.py @@ -0,0 +1,330 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:light +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.14.4 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# - + +# # SPADE VAE-GAN + +# This notebook creates a mock SPADE VAE-GAN based on the paper "Semantic Image Synthesis with Spatially-Adaptive Normalization" (2019) by Park T, Liu MY, Wang TC, Zhu JY. 
More information available at: https://github.com/NVlabs/SPADE + +import os +import tempfile +import matplotlib.pyplot as plt +import numpy as np +import torch +from pathlib import Path +import zipfile +import gdown +from monai.data import DataLoader +from tqdm import tqdm +from generative.losses import PatchAdversarialLoss, PerceptualLoss +from generative.networks.nets import MultiScalePatchDiscriminator +import numpy as np +import monai +from generative.networks.nets.spade_network import SPADE_Net + +directory = os.environ.get("MONAI_DATA_DIRECTORY") +root_dir = tempfile.mkdtemp() if directory is None else directory +root_dir = Path(root_dir) +print("Temporary directory used: %s " %root_dir) + +# INPUT PARAMETERS +input_shape = [128, 128] +batch_size = 6 +num_workers = 4 +num_epochs = 100 +lambda_perc = 1.0 +lambda_feat = 0.1 +lambda_kld = 0.00001 +loss_adv = 1.0 + +# ### Data + +# The data for this notebook comes from the public dataset OASIS (Open Access Series of Imaging Studies) [1]. The images have been registered to MNI space using ANTsPy, and then subsampled to 2mm isotropic resolution. Geodesic Information Flows (GIF) [2] has been used to segment 5 regions: cerebrospinal fluid (CSF), grey matter (GM), white matter (WM), deep grey matter (DGM) and brainstem. In addition, BaMos [3] has been used to provide white matter hyperintensities segmentations (WMH). The available dataset contains: +# - T1-weighted images +# - FLAIR weighted images +# - Segmentations with the following labels: 0 (background), 1 (CSF), 2 (GM), 3 (WM), 4 (DGM), 5 (brainstem) and 6 (WMH). +# +# _**Acknowledgments**: "Data were provided by OASIS-3: Longitudinal Multimodal Neuroimaging: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P30 AG066444, P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. 
AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly owned subsidiary of Eli Lilly."_ +# +# +# Citations: +# +# [1] Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner, RL. Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults. Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498 +# +# [2] Cardoso MJ, Modat M, Wolz R, Melbourne A, Cash D, Rueckert D, Ourselin S. Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion. IEEE Trans Med Imaging. 2015 Sep;34(9):1976-88. doi: 10.1109/TMI.2015.2418298. Epub 2015 Apr 14. PMID: 25879909. +# +# [3] Fiford CM, Sudre CH, Pemberton H, Walsh P, Manning E, Malone IB, Nicholas J, Bouvy WH, Carmichael OT, Biessels GJ, Cardoso MJ, Barnes J; Alzheimer’s Disease Neuroimaging Initiative. Automated White Matter Hyperintensity Segmentation Using Bayesian Model Selection: Assessment and Correlations with Cognitive Change. Neuroinformatics. 2020 Jun;18(3):429-449. doi: 10.1007/s12021-019-09439-6. PMID: 32062817; PMCID: PMC7338814. 
#

# Download the archive with the OASIS subset and unpack it under root_dir.
zip_path = root_dir / 'data.zip'
gdown.download("https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5",
               str(zip_path))

# The context manager closes the archive handle even if extraction fails.
with zipfile.ZipFile(zip_path, 'r') as zip_obj:
    zip_obj.extractall(root_dir)
images_T1 = root_dir / "OASIS_SMALL-SUBSET/T1"
images_FLAIR = root_dir / "OASIS_SMALL-SUBSET/FLAIR"
labels = root_dir / "OASIS_SMALL-SUBSET/Segmentations"

# We create the data dictionaries that we need
all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + \
             [os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)]
np.random.shuffle(all_images)

# The label file of an image has the same name with the leading token (the text
# before the first "_") swapped for "Parcellation". Only that leading
# occurrence is replaced, so a repeat of the token later in the name is kept.
corresponding_labels = []
for image_path in all_images:
    file_name = os.path.basename(image_path)
    modality_prefix = file_name.split("_")[0]
    corresponding_labels.append(
        os.path.join(labels, file_name.replace(modality_prefix, "Parcellation", 1)))

input_dict = [{'image': image_path, 'label': label_path}
              for image_path, label_path in zip(all_images, corresponding_labels)]

# 90 % of the (shuffled) pairs go to training, the remainder to validation.
split_index = int(len(input_dict) * 0.9)
input_dict_train = input_dict[:split_index]
input_dict_val = input_dict[split_index:]

# ### Dataloaders

# +
preliminar_shape = input_shape + [50]  # We take random slices from the center of the brain
crop_shape = input_shape + [1]

# Transforms shared by training and validation: load, centre-crop a 50-slice
# slab, pick one random slice out of it, drop the singleton slice axis and
# resize to input_shape.
base_transforms = [
    monai.transforms.LoadImaged(keys=['label', 'image']),
    monai.transforms.EnsureChannelFirstd(keys=['image', 'label']),
    monai.transforms.CenterSpatialCropd(keys=['label', 'image'],
                                        roi_size=preliminar_shape),
    monai.transforms.RandSpatialCropd(keys=['label', 'image'],
                                      roi_size=crop_shape, max_roi_size=crop_shape),
    monai.transforms.SqueezeDimd(keys=['label', 'image'], dim=-1),
    monai.transforms.Resized(keys=['image', 'label'], spatial_size=input_shape),
]

# Applied after any augmentation: build a foreground mask from the non-zero
# labels, mask the image with it, normalise intensities and convert to tensors.
last_transforms = [
    monai.transforms.CopyItemsd(keys=['label'], names=['label_channel']),
    monai.transforms.Lambdad(keys=['label_channel'],
                             func=lambda label_map: label_map != 0),
    monai.transforms.MaskIntensityd(keys=['image'], mask_key='label_channel'),
    monai.transforms.NormalizeIntensityd(keys=['image']),
    monai.transforms.ToTensord(keys=['image', 'label'])
]

+aug_transforms = [ + monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=['image']), + monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=['image']), + monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), + keys=['image']), + monai.transforms.RandAffined(rotate_range=[-0.05, 0.05], shear_range=[0.001, 0.05], + scale_range=[0, 0.05], padding_mode='zeros', + mode='nearest', prob=0.33, keys=['label', 'image']) + ] + +train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms) +val_transforms = monai.transforms.Compose(base_transforms + last_transforms) + +train_dataset = monai.data.dataset.Dataset(input_dict_train, train_transforms) +val_dataset = monai.data.dataset.Dataset(input_dict_val, val_transforms) +train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers) +val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, batch_size=batch_size, num_workers=num_workers) + + +# - + +# Sanity check +batch = next(iter(train_loader)) +print(batch['image'].shape) +plt.subplot(1,2,1) +plt.imshow(batch['image'][0,0,...], cmap = 'gist_gray'); plt.axis('off') +plt.subplot(1,2,2) +plt.imshow(batch['label'][0,0,...], cmap = "jet"); plt.axis('off') +plt.show() + +# ### Network creation and losses + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def one_hot(input_label, label_nc): + # One hot encoding function for the labels + shape_ = list(input_label.shape) + shape_[1] = label_nc + label_out = torch.zeros(shape_) + for channel in range(label_nc): + label_out[:, channel, ...] = input_label[:, 0, ...] 
== channel + return label_out + + + +def picture_results(input_label, input_image, output_image): + f = plt.figure(figsize = (4, 1.5)) + plt.subplot(1,3,1) + plt.imshow(torch.argmax(input_label, 1)[0,...].detach().cpu(), cmap = 'jet') + plt.axis('off') + plt.title("Label") + plt.subplot(1,3,2) + plt.imshow(input_image[0,0,...].detach().cpu(), cmap = 'gist_gray') + plt.axis('off') + plt.title("Input image") + plt.subplot(1,3,3) + plt.imshow(output_image[0,0,...].detach().cpu(), cmap = 'gist_gray') + plt.axis('off') + plt.title("Output image") + plt.show() + + + +def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat, device): + criterion = torch.nn.L1Loss() + num_D = len(input_features_disc_fake) + GAN_Feat_loss = torch.zeros(1).to(device) + for i in range(num_D): # for each discriminator + num_intermediate_outputs = len(input_features_disc_fake[i]) + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = criterion(input_features_disc_fake[i][j], + input_features_disc_real[i][j].detach()) + GAN_Feat_loss += unweighted_loss * lambda_feat / num_D + return GAN_Feat_loss + + +net = SPADE_Net(spatial_dims = 2, + in_channels = 1, + out_channels = 1, + label_nc = 6, + input_shape = input_shape, + num_channels = [16, 32, 64, 128], + z_dim = 16, + is_vae = True) + +# + +discriminator = MultiScalePatchDiscriminator(num_d = 2, + num_layers_d = 3, + spatial_dims = 2, + num_channels = 8, + in_channels = 7, + out_channels = 7, + minimum_size_im = 128, + norm = "INSTANCE", + kernel_size = 3 + ) + +adversarial_loss = PatchAdversarialLoss(reduction = "sum", criterion = "hinge") +# - + +perceptual_loss = PerceptualLoss(spatial_dims = 2, + network_type = "vgg", + is_fake_3d = False, + pretrained = True) +perceptual_loss=perceptual_loss.to(device) + +optimizer_G = torch.optim.Adam(net.parameters(), lr = 0.0002) +optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = 0.0004) + +# ### Training loop +# + +net = 
net.to(device) +discriminator = discriminator.to(device) +torch.autograd.set_detect_anomaly(True) +for epoch in range(num_epochs): + print("Epoch %d/%d" %(epoch, num_epochs)) + train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120) + losses_epoch = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0} + for step, d in train_bar: + image = d['image'].to(device) + with torch.no_grad(): + label = one_hot(d['label'], 6).to(device) + optimizer_G.zero_grad() + + # Losses gen + out, kld_loss = net(label, image) + disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1)) + loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False) + disc_reals, features_reals = discriminator(torch.cat([image, label], 1)) + loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device) + loss_perc = perceptual_loss(out, target = image) + total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat + total_loss.backward(retain_graph = True) + optimizer_G.step() + + # Store + losses_epoch['kld'] += kld_loss.item() + losses_epoch['perceptual'] += loss_perc.item() + losses_epoch['generator'] += loss_g.item() + #Train disc + out, _ = net(label, image) + disc_fakes, _ = discriminator(torch.cat([out, label], 1)) + loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True) + loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True) + optimizer_D.zero_grad() + loss_d = loss_d_r + loss_g_f + loss_d.backward() + optimizer_D.step() + + # Store + losses_epoch['feature'] = loss_feat.item() + losses_epoch['discriminator'] = loss_d_r.item() + loss_g_f.item() + + train_bar.set_postfix( + {"kld": kld_loss.item(), + "perceptual": loss_perc.item(), + "generator": loss_g.item(), + "feature": loss_feat.item(), + "discriminator": loss_d_r.item() + loss_g_f.item(), + }) + + val_bar = tqdm(enumerate(val_loader), 
total=len(val_loader), ncols=120) + losses_epoch_val = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0} + for step, d in val_bar: + image = d['image'].to(device) + with torch.no_grad(): + label = one_hot(d['label'], 6).to(device) + # Losses gen + out, kld_loss = net(label, image) + disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1)) + loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False) + disc_reals, features_reals = discriminator(torch.cat([image, label], 1)) + loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device) + loss_perc = perceptual_loss(out, target = image) + total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat + # Store + losses_epoch_val['kld'] += kld_loss.item() + losses_epoch_val['perceptual'] += loss_perc.item() + losses_epoch_val['generator'] += loss_g.item() + #Train disc + out, _ = net(label, image) + disc_fakes, _ = discriminator(torch.cat([out, label], 1)) + loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True) + loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True) + loss_d = loss_adv * (loss_d_r + loss_g_f) + + # Store + losses_epoch_val['feature'] = loss_feat.item() + losses_epoch_val['discriminator'] = loss_d_r.item() + loss_g_f.item() + + val_bar.set_postfix( + {"kld": kld_loss.item(), + "perceptual": loss_perc.item(), + "generator": loss_g.item(), + "feature": loss_feat.item(), + "discriminator": loss_d_r.item() + loss_g_f.item(), + }) + if step == 0 and epoch%10==0: + picture_results(label, image, out) + +# + pycharm={"name": "#%%"} + + From e77599e4d1728ccb4659133d480a5dd5e68f87d7 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Thu, 22 Jun 2023 09:22:17 +0100 Subject: [PATCH 4/7] Add tutorial outputs Signed-off-by: Walter Hugo Lopez Pinaya --- .../2d_spade_gan/2d_spade_vae.ipynb | 1723 
+++++++++++++++-- 1 file changed, 1593 insertions(+), 130 deletions(-) diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb index c7c29aaf..bf7271c3 100644 --- a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "1f7ba8ce", "metadata": {}, "outputs": [], @@ -37,12 +37,23 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "id": "e059c423", "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "A matching Triton is not available, some optimizations will not be enabled.\n", + "Error caught was: No module named 'triton'\n" + ] + } + ], "source": [ "import os\n", "import tempfile\n", @@ -63,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "id": "e76296e7", "metadata": { "scrolled": false @@ -73,7 +84,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Temporary directory used: /tmp/tmpy_otj3u5 \n" + "Temporary directory used: /tmp/tmpo8gppqh6 \n" ] } ], @@ -86,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "id": "2483148a", "metadata": {}, "outputs": [], @@ -134,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "dc560f7e", "metadata": { "scrolled": true @@ -146,17 +157,17 @@ "text": [ "Downloading...\n", "From: https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\n", - "To: 
/tmp/tmpy_otj3u5/data.zip\n", - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384M/384M [00:06<00:00, 62.8MB/s]\n" + "To: /tmp/tmpo8gppqh6/data.zip\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384M/384M [00:05<00:00, 69.6MB/s]\n" ] }, { "data": { "text/plain": [ - "'/tmp/tmpy_otj3u5/data.zip'" + "'/tmp/tmpo8gppqh6/data.zip'" ] }, - "execution_count": 9, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -168,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "id": "cd7dd6ec", "metadata": {}, "outputs": [], @@ -182,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "id": "d48987b9", "metadata": {}, "outputs": [], @@ -208,12 +219,21 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "id": "8ab79cdc", "metadata": { "lines_to_end_of_cell_marker": 2 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.\n", + "monai.transforms.croppad.dictionary RandSpatialCropd.__init__:random_size: Current default value of argument `random_size=True` has been deprecated since version 1.1. 
It will be changed to `random_size=False` in version 1.3.\n" + ] + } + ], "source": [ "preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain\n", "crop_shape = input_shape + [1]\n", @@ -257,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "id": "98d14e75", "metadata": {}, "outputs": [ @@ -270,7 +290,7 @@ }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -300,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "id": "fa17d864", "metadata": {}, "outputs": [], @@ -310,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "id": "8f126b17", "metadata": {}, "outputs": [], @@ -335,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "id": "255c90c7", "metadata": {}, "outputs": [], @@ -367,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 13, "id": "c18dbad8", "metadata": {}, "outputs": [], @@ -387,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 14, "id": "89989c34", "metadata": {}, "outputs": [], @@ -404,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 15, "id": "5b8b676f", "metadata": {}, "outputs": [], @@ -425,7 +445,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 16, "id": "36ea4308", "metadata": { "scrolled": false @@ -436,7 +456,9 @@ "output_type": "stream", "text": [ "The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", - "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n" + "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. 
You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n", + "Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /home/walter/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 528M/528M [00:07<00:00, 77.9MB/s]\n" ] } ], @@ -450,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 17, "id": "3b57abd3", "metadata": {}, "outputs": [], @@ -469,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "id": "918eac0a", "metadata": { "scrolled": false @@ -486,13 +508,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|█████| 15/15 [00:08<00:00, 1.74it/s, kld=378, perceptual=0.395, generator=2.04, feature=0.206, discriminator=3.69]\n", - " 0%| | 0/2 [00:00" ] @@ -504,7 +526,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████| 2/2 [00:00<00:00, 3.18it/s, kld=252, perceptual=0.362, generator=2.1, feature=0.204, discriminator=3.61]" + "100%|███████| 2/2 [00:00<00:00, 3.21it/s, kld=197, perceptual=0.339, generator=1.93, feature=0.121, discriminator=3.94]" ] }, { @@ -519,32 +541,31 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=279, perceptual=0.395, generator=2.15, feature=0.202, discriminator=3.7]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████| 2/2 [00:00<00:00, 3.25it/s, kld=184, perceptual=0.421, generator=2.2, feature=0.196, discriminator=3.36]" + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=155, perceptual=0.238, generator=2.41, feature=0.147, discriminator=3.03]\n", + 
"100%|████████| 2/2 [00:00<00:00, 3.92it/s, kld=179, perceptual=0.34, generator=2.41, feature=0.183, discriminator=2.95]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 2/100\n" + "Epoch 3/100\n" ] }, { @@ -552,32 +573,31 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.14it/s, kld=207, perceptual=0.421, generator=2.32, feature=0.214, discriminator=2.83]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.47it/s, kld=125, perceptual=0.414, generator=2.33, feature=0.222, discriminator=2.82]" + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=483, perceptual=0.47, generator=2.29, feature=0.238, discriminator=2.82]\n", + "100%|███████| 2/2 [00:00<00:00, 4.43it/s, kld=296, perceptual=0.449, generator=2.33, feature=0.237, discriminator=2.78]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 3/100\n" + "Epoch 5/100\n" ] }, { @@ -585,32 +605,31 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=131, perceptual=0.384, generator=2.31, feature=0.219, discriminator=2.68]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████| 2/2 [00:00<00:00, 3.38it/s, kld=87.2, perceptual=0.403, generator=2.35, feature=0.221, discriminator=2.59]" + "\n", + "100%|████| 15/15 [00:06<00:00, 2.17it/s, kld=80.9, perceptual=0.406, generator=2.88, feature=0.245, discriminator=2.05]\n", + "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=51.4, perceptual=0.422, generator=2.89, feature=0.249, discriminator=1.97]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 4/100\n" + 
"Epoch 7/100\n" ] }, { @@ -618,32 +637,47 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:07<00:00, 2.12it/s, kld=170, perceptual=0.353, generator=2.46, feature=0.17, discriminator=2.51]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.34it/s, kld=113, perceptual=0.339, generator=2.48, feature=0.174, discriminator=2.51]" + "\n", + "100%|███████| 15/15 [00:07<00:00, 2.14it/s, kld=130, perceptual=0.4, generator=2.91, feature=0.203, discriminator=1.89]\n", + "100%|███████| 2/2 [00:00<00:00, 3.42it/s, kld=101, perceptual=0.385, generator=2.94, feature=0.185, discriminator=2.08]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 5/100\n" + "Epoch 9/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=156, perceptual=0.352, generator=2.68, feature=0.184, discriminator=2.03]\n", + "100%|██████| 2/2 [00:00<00:00, 4.28it/s, kld=97.8, perceptual=0.343, generator=2.63, feature=0.183, discriminator=2.08]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10/100\n" ] }, { @@ -651,13 +685,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=181, perceptual=0.321, generator=2.52, feature=0.203, discriminator=2.23]\n", - " 0%| | 0/2 [00:00" ] @@ -669,14 +703,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 2.59it/s, kld=151, perceptual=0.348, generator=2.49, feature=0.198, discriminator=2.25]" + "100%|███████| 2/2 [00:00<00:00, 3.39it/s, kld=149, perceptual=0.372, generator=2.91, feature=0.186, discriminator=1.79]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 6/100\n" + "Epoch 11/100\n" ] }, { @@ -684,32 
+718,31 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=179, perceptual=0.365, generator=1.58, feature=0.125, discriminator=3.83]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.08it/s, kld=108, perceptual=0.384, generator=1.93, feature=0.183, discriminator=3.43]" + "\n", + "100%|█| 15/15 [00:06<00:00, 2.23it/s, kld=1.21e+3, perceptual=0.379, generator=2.97, feature=0.186, discriminator=1.68]\n", + "100%|██████████| 2/2 [00:00<00:00, 4.23it/s, kld=817, perceptual=0.401, generator=3, feature=0.189, discriminator=1.66]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 7/100\n" + "Epoch 13/100\n" ] }, { @@ -717,32 +750,111 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.15it/s, kld=407, perceptual=0.38, generator=1.84, feature=0.127, discriminator=3.41]\n", - " 0%| | 0/2 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14/100\n" + ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.35it/s, kld=302, perceptual=0.393, generator=2.02, feature=0.133, discriminator=2.79]" + "\n", + "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=2.05e+3, perceptual=0.386, generator=2.48, feature=0.177, discriminator=2.18]\n", + "100%|███| 2/2 [00:00<00:00, 4.27it/s, kld=1.31e+3, perceptual=0.326, generator=2.54, feature=0.166, discriminator=2.24]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 8/100\n" + "Epoch 15/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.18it/s, kld=1.19e+3, perceptual=0.337, generator=2.82, feature=0.173, discriminator=1.79]\n", + 
"100%|████████| 2/2 [00:00<00:00, 4.00it/s, kld=730, perceptual=0.329, generator=2.82, feature=0.17, discriminator=1.78]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=383, perceptual=0.348, generator=2.8, feature=0.144, discriminator=1.78]\n", + "100%|████████| 2/2 [00:00<00:00, 3.79it/s, kld=204, perceptual=0.328, generator=2.87, feature=0.146, discriminator=1.8]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=327, perceptual=0.349, generator=2.62, feature=0.0884, discriminator=3.06]\n", + "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=198, perceptual=0.337, generator=2.67, feature=0.0891, discriminator=3.01]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.17it/s, kld=780, perceptual=0.363, generator=2.98, feature=0.0928, discriminator=2.89]\n", + "100%|██████| 2/2 [00:00<00:00, 3.25it/s, kld=522, perceptual=0.379, generator=2.94, feature=0.0919, discriminator=2.82]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=711, perceptual=0.346, generator=1.89, feature=0.0815, discriminator=3.88]\n", + "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=413, perceptual=0.327, generator=2.01, feature=0.0848, discriminator=3.68]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20/100\n" ] }, { @@ -750,13 +862,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 
[00:06<00:00, 2.15it/s, kld=356, perceptual=0.365, generator=2.34, feature=0.166, discriminator=2.39]\n", - " 0%| | 0/2 [00:00" ] @@ -768,14 +880,62 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████| 2/2 [00:00<00:00, 3.20it/s, kld=237, perceptual=0.37, generator=2.35, feature=0.176, discriminator=2.37]" + "100%|███████| 2/2 [00:00<00:00, 3.01it/s, kld=203, perceptual=0.336, generator=1.99, feature=0.078, discriminator=3.71]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9/100\n" + "Epoch 21/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=174, perceptual=0.353, generator=2.58, feature=0.0856, discriminator=3.15]\n", + "100%|██████| 2/2 [00:00<00:00, 4.24it/s, kld=114, perceptual=0.274, generator=2.73, feature=0.0881, discriminator=3.83]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=195, perceptual=0.36, generator=2.49, feature=0.106, discriminator=3.85]\n", + "100%|███████| 2/2 [00:00<00:00, 3.47it/s, kld=141, perceptual=0.343, generator=2.38, feature=0.0823, discriminator=3.8]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=407, perceptual=0.355, generator=2.13, feature=0.0831, discriminator=3.37]\n", + "100%|███████| 2/2 [00:00<00:00, 4.19it/s, kld=190, perceptual=0.316, generator=2.47, feature=0.108, discriminator=3.25]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24/100\n" ] }, { @@ -783,24 +943,1327 @@ "output_type": "stream", "text": [ "\n", - " 60%|████▏ | 9/15 [00:04<00:03, 1.87it/s, kld=194, perceptual=0.35, generator=2.63, feature=0.156, 
discriminator=2.57]\n" + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=380, perceptual=0.375, generator=1.69, feature=0.108, discriminator=3.62]\n", + "100%|███████| 2/2 [00:00<00:00, 4.01it/s, kld=234, perceptual=0.348, generator=1.86, feature=0.103, discriminator=3.33]" ] }, { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[22], line 36\u001B[0m\n\u001B[1;32m 34\u001B[0m optimizer_D\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[1;32m 35\u001B[0m loss_d \u001B[38;5;241m=\u001B[39m loss_d_r \u001B[38;5;241m+\u001B[39m loss_g_f\n\u001B[0;32m---> 36\u001B[0m \u001B[43mloss_d\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 37\u001B[0m optimizer_D\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m 39\u001B[0m \u001B[38;5;66;03m# Store\u001B[39;00m\n", - "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:479\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 432\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124;03m\"\"\"Computes the gradient of current tensor w.r.t. graph leaves.\u001B[39;00m\n\u001B[1;32m 433\u001B[0m \n\u001B[1;32m 434\u001B[0m \u001B[38;5;124;03mThe graph is differentiated using the chain rule. 
If the tensor is\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 476\u001B[0m \u001B[38;5;124;03m used to compute the attr::tensors.\u001B[39;00m\n\u001B[1;32m 477\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m--> 479\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mhandle_torch_function\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 480\u001B[0m \u001B[43m \u001B[49m\u001B[43mTensor\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 481\u001B[0m \u001B[43m \u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 482\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 483\u001B[0m \u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 484\u001B[0m \u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 485\u001B[0m \u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 486\u001B[0m \u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 487\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 488\u001B[0m torch\u001B[38;5;241m.\u001B[39mautograd\u001B[38;5;241m.\u001B[39mbackward(\n\u001B[1;32m 489\u001B[0m \u001B[38;5;28mself\u001B[39m, gradient, retain_graph, create_graph, inputs\u001B[38;5;241m=\u001B[39minputs\n\u001B[1;32m 490\u001B[0m )\n", - "File 
\u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/overrides.py:1534\u001B[0m, in \u001B[0;36mhandle_torch_function\u001B[0;34m(public_api, relevant_args, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1528\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDefining your `__torch_function__ as a plain method is deprecated and \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 1529\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mwill be an error in future, please define it as a classmethod.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m 1530\u001B[0m \u001B[38;5;167;01mDeprecationWarning\u001B[39;00m)\n\u001B[1;32m 1532\u001B[0m \u001B[38;5;66;03m# Use `public_api` instead of `implementation` so __torch_function__\u001B[39;00m\n\u001B[1;32m 1533\u001B[0m \u001B[38;5;66;03m# implementations can do equality/identity comparisons.\u001B[39;00m\n\u001B[0;32m-> 1534\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[43mtorch_func_method\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpublic_api\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtypes\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1536\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m result \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mNotImplemented\u001B[39m:\n\u001B[1;32m 1537\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m result\n", - "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/monai/data/meta_tensor.py:276\u001B[0m, in \u001B[0;36mMetaTensor.__torch_function__\u001B[0;34m(cls, func, types, args, kwargs)\u001B[0m\n\u001B[1;32m 274\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m kwargs \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 275\u001B[0m kwargs \u001B[38;5;241m=\u001B[39m 
{}\n\u001B[0;32m--> 276\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__torch_function__\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfunc\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtypes\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 277\u001B[0m \u001B[38;5;66;03m# if `out` has been used as argument, metadata is not copied, nothing to do.\u001B[39;00m\n\u001B[1;32m 278\u001B[0m \u001B[38;5;66;03m# if \"out\" in kwargs:\u001B[39;00m\n\u001B[1;32m 279\u001B[0m \u001B[38;5;66;03m# return ret\u001B[39;00m\n\u001B[1;32m 280\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m _not_requiring_metadata(ret):\n", - "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:1279\u001B[0m, in \u001B[0;36mTensor.__torch_function__\u001B[0;34m(cls, func, types, args, kwargs)\u001B[0m\n\u001B[1;32m 1276\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mNotImplemented\u001B[39m\n\u001B[1;32m 1278\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m _C\u001B[38;5;241m.\u001B[39mDisableTorchFunction():\n\u001B[0;32m-> 1279\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1280\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m func \u001B[38;5;129;01min\u001B[39;00m get_default_nowrap_functions():\n\u001B[1;32m 1281\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m ret\n", - "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/_tensor.py:488\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, 
gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 479\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 480\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 481\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 486\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m 487\u001B[0m )\n\u001B[0;32m--> 488\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 489\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m 490\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/GenerativeModels/venv/lib/python3.9/site-packages/torch/autograd/__init__.py:197\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 192\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 194\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 195\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 196\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 197\u001B[0m 
\u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 198\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 199\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n", - "\u001B[0;31mKeyboardInterrupt\u001B[0m: " + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=207, perceptual=0.345, generator=1.8, feature=0.0924, discriminator=3.78]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.41it/s, kld=129, perceptual=0.348, generator=1.97, feature=0.1, discriminator=3.72]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=294, perceptual=0.334, generator=2.76, feature=0.119, discriminator=2.83]\n", + "100%|████████| 2/2 [00:00<00:00, 4.38it/s, kld=211, perceptual=0.355, generator=2.52, feature=0.123, discriminator=2.9]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27/100\n" + ] + }, + { 
+ "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.20it/s, kld=384, perceptual=0.361, generator=2.86, feature=0.126, discriminator=2.7]\n", + "100%|███████| 2/2 [00:00<00:00, 3.77it/s, kld=164, perceptual=0.341, generator=3.05, feature=0.147, discriminator=2.77]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=257, perceptual=0.354, generator=1.4, feature=0.0816, discriminator=3.77]\n", + "100%|███████| 2/2 [00:00<00:00, 4.09it/s, kld=173, perceptual=0.356, generator=1.36, feature=0.0792, discriminator=3.7]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=166, perceptual=0.366, generator=1.22, feature=0.0643, discriminator=3.9]\n", + "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=92.5, perceptual=0.314, generator=1.46, feature=0.065, discriminator=3.55]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=171, perceptual=0.331, generator=1.94, feature=0.0693, discriminator=3.74]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.36it/s, kld=143, perceptual=0.311, generator=1.89, feature=0.0648, discriminator=3.33]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=264, perceptual=0.363, generator=1.98, 
feature=0.0766, discriminator=3.99]\n", + "100%|██████| 2/2 [00:00<00:00, 3.75it/s, kld=178, perceptual=0.348, generator=1.96, feature=0.0681, discriminator=3.98]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███| 15/15 [00:06<00:00, 2.20it/s, kld=85.9, perceptual=0.308, generator=2.05, feature=0.0725, discriminator=3.94]\n", + "100%|██████| 2/2 [00:00<00:00, 3.43it/s, kld=55.1, perceptual=0.327, generator=2.17, feature=0.113, discriminator=3.84]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=110, perceptual=0.293, generator=2.17, feature=0.0944, discriminator=3.92]\n", + "100%|██████| 2/2 [00:00<00:00, 3.94it/s, kld=110, perceptual=0.329, generator=2.01, feature=0.0707, discriminator=4.03]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=203, perceptual=0.329, generator=2.42, feature=0.0917, discriminator=3.65]\n", + "100%|███████| 2/2 [00:00<00:00, 2.71it/s, kld=106, perceptual=0.29, generator=2.39, feature=0.0954, discriminator=3.58]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=282, perceptual=0.28, generator=2.44, feature=0.0842, discriminator=3.62]\n", + "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=206, perceptual=0.284, generator=2.42, feature=0.0789, discriminator=3.64]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36/100\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=313, perceptual=0.333, generator=2.3, feature=0.0794, discriminator=4.01]\n", + "100%|███████| 2/2 [00:00<00:00, 4.14it/s, kld=192, perceptual=0.285, generator=2.4, feature=0.0781, discriminator=4.01]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=101, perceptual=0.297, generator=2.4, feature=0.0699, discriminator=3.96]\n", + "100%|█████| 2/2 [00:00<00:00, 3.69it/s, kld=56.2, perceptual=0.279, generator=2.44, feature=0.0671, discriminator=3.96]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███| 15/15 [00:06<00:00, 2.17it/s, kld=94.7, perceptual=0.296, generator=2.37, feature=0.0693, discriminator=3.91]\n", + "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=64, perceptual=0.239, generator=2.42, feature=0.0572, discriminator=3.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=154, perceptual=0.293, generator=2.28, feature=0.055, discriminator=3.9]\n", + "100%|███████| 2/2 [00:00<00:00, 3.99it/s, kld=93.5, perceptual=0.27, generator=2.25, feature=0.0545, discriminator=3.9]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=268, perceptual=0.227, generator=2.26, feature=0.052, discriminator=3.98]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|███████| 2/2 [00:00<00:00, 3.34it/s, kld=247, perceptual=0.269, generator=2.15, feature=0.056, discriminator=3.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=195, perceptual=0.231, generator=2.14, feature=0.054, discriminator=4.01]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=128, perceptual=0.238, generator=2.14, feature=0.0585, discriminator=4.02]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=312, perceptual=0.238, generator=2.29, feature=0.0563, discriminator=3.91]\n", + "100%|███████| 2/2 [00:00<00:00, 3.29it/s, kld=217, perceptual=0.24, generator=2.29, feature=0.0595, discriminator=3.89]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=653, perceptual=0.225, generator=2.5, feature=0.0637, discriminator=3.79]\n", + "100%|██████| 2/2 [00:00<00:00, 4.31it/s, kld=390, perceptual=0.256, generator=2.43, feature=0.0642, discriminator=3.72]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.21it/s, kld=860, perceptual=0.261, generator=2.38, feature=0.0658, discriminator=3.85]\n", + "100%|██████| 2/2 [00:00<00:00, 3.70it/s, kld=513, perceptual=0.241, generator=2.41, feature=0.0746, discriminator=3.82]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 
15/15 [00:06<00:00, 2.22it/s, kld=422, perceptual=0.282, generator=2.38, feature=0.0676, discriminator=3.99]\n", + "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=325, perceptual=0.234, generator=2.5, feature=0.0701, discriminator=4.01]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=2.07e+3, perceptual=0.257, generator=2.42, feature=0.0706, discriminator=3.89\n", + "100%|██| 2/2 [00:00<00:00, 3.85it/s, kld=1.21e+3, perceptual=0.247, generator=2.43, feature=0.0692, discriminator=3.88]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=338, perceptual=0.263, generator=2.39, feature=0.0861, discriminator=4.03]\n", + "100%|███████| 2/2 [00:00<00:00, 4.18it/s, kld=248, perceptual=0.246, generator=2.5, feature=0.0764, discriminator=3.86]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=331, perceptual=0.246, generator=2.47, feature=0.0839, discriminator=3.71]\n", + "100%|███████| 2/2 [00:00<00:00, 4.31it/s, kld=205, perceptual=0.243, generator=2.39, feature=0.088, discriminator=3.79]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██| 15/15 [00:06<00:00, 2.19it/s, kld=3.41e+3, perceptual=0.248, generator=2.58, feature=0.12, discriminator=3.76]\n", + "100%|████| 2/2 [00:00<00:00, 3.71it/s, kld=2.19e+3, perceptual=0.26, generator=2.55, feature=0.125, discriminator=3.59]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Epoch 50/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.21it/s, kld=1.09e+3, perceptual=0.246, generator=2.76, feature=0.128, discriminator=3.32]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████| 2/2 [00:00<00:00, 3.31it/s, kld=704, perceptual=0.241, generator=2.83, feature=0.145, discriminator=3.18]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 51/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=1.35e+3, perceptual=0.247, generator=1.95, feature=0.0697, discriminator=3.94\n", + "100%|██████| 2/2 [00:00<00:00, 4.27it/s, kld=891, perceptual=0.269, generator=1.84, feature=0.0785, discriminator=4.04]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 52/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=373, perceptual=0.286, generator=1.98, feature=0.0785, discriminator=3.87]\n", + "100%|███████| 2/2 [00:00<00:00, 3.92it/s, kld=274, perceptual=0.285, generator=2.02, feature=0.0876, discriminator=3.8]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 53/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.16it/s, kld=637, perceptual=0.264, generator=2.5, feature=0.0888, discriminator=3.63]\n", + "100%|██████| 2/2 [00:00<00:00, 3.50it/s, kld=417, perceptual=0.268, generator=2.46, feature=0.0777, discriminator=3.49]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 54/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=584, 
perceptual=0.274, generator=2.46, feature=0.095, discriminator=3.23]\n", + "100%|██████| 2/2 [00:00<00:00, 4.24it/s, kld=420, perceptual=0.254, generator=2.53, feature=0.0873, discriminator=3.37]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 55/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.19it/s, kld=842, perceptual=0.272, generator=1.82, feature=0.0686, discriminator=3.99]\n", + "100%|██████| 2/2 [00:00<00:00, 4.11it/s, kld=521, perceptual=0.256, generator=1.91, feature=0.0724, discriminator=3.97]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 56/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=301, perceptual=0.292, generator=2.01, feature=0.0822, discriminator=3.88]\n", + "100%|███████| 2/2 [00:00<00:00, 4.03it/s, kld=169, perceptual=0.291, generator=2.27, feature=0.103, discriminator=3.59]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 57/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=835, perceptual=0.302, generator=2.35, feature=0.104, discriminator=3.89]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.33it/s, kld=517, perceptual=0.341, generator=2.3, feature=0.119, discriminator=3.6]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 58/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=425, perceptual=0.322, generator=2.7, feature=0.174, discriminator=3.03]\n", + "100%|███████| 2/2 [00:00<00:00, 4.24it/s, kld=245, perceptual=0.326, generator=2.72, feature=0.162, discriminator=3.07]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 59/100\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=139, perceptual=0.302, generator=2.79, feature=0.149, discriminator=2.91]\n", + "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=69.9, perceptual=0.284, generator=2.83, feature=0.137, discriminator=2.89]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 60/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=210, perceptual=0.278, generator=2.84, feature=0.133, discriminator=2.84]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.44it/s, kld=132, perceptual=0.257, generator=2.86, feature=0.12, discriminator=2.85]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 61/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=835, perceptual=0.266, generator=1.8, feature=0.0844, discriminator=3.96]\n", + "100%|██████| 2/2 [00:00<00:00, 3.26it/s, kld=767, perceptual=0.293, generator=1.72, feature=0.0979, discriminator=3.83]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 62/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=908, perceptual=0.306, generator=1.62, feature=0.102, discriminator=3.5]\n", + "100%|███████| 2/2 [00:00<00:00, 4.25it/s, kld=568, perceptual=0.324, generator=1.66, feature=0.107, discriminator=3.45]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 63/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=865, perceptual=0.304, generator=1.54, feature=0.0938, 
discriminator=3.68]\n", + "100%|██████| 2/2 [00:00<00:00, 3.77it/s, kld=668, perceptual=0.295, generator=1.62, feature=0.0881, discriminator=3.65]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 64/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=915, perceptual=0.272, generator=1.26, feature=0.074, discriminator=3.78]\n", + "100%|██████| 2/2 [00:00<00:00, 4.09it/s, kld=570, perceptual=0.299, generator=1.21, feature=0.0785, discriminator=3.76]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 65/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=343, perceptual=0.268, generator=1.28, feature=0.0697, discriminator=3.87]\n", + "100%|███████| 2/2 [00:00<00:00, 4.37it/s, kld=209, perceptual=0.278, generator=1.2, feature=0.0735, discriminator=3.81]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 66/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=495, perceptual=0.252, generator=1.18, feature=0.0635, discriminator=3.82]\n", + "100%|██████| 2/2 [00:00<00:00, 3.88it/s, kld=337, perceptual=0.267, generator=1.15, feature=0.0676, discriminator=3.88]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 67/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.21it/s, kld=677, perceptual=0.271, generator=1.25, feature=0.0701, discriminator=3.95]\n", + "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=563, perceptual=0.279, generator=1.11, feature=0.0637, discriminator=3.93]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 68/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "\n", + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=373, perceptual=0.283, generator=1.24, feature=0.0668, discriminator=3.91]\n", + "100%|██████| 2/2 [00:00<00:00, 3.75it/s, kld=235, perceptual=0.266, generator=1.25, feature=0.0631, discriminator=3.89]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 69/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=455, perceptual=0.275, generator=1.23, feature=0.0644, discriminator=3.96]\n", + "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=363, perceptual=0.263, generator=1.19, feature=0.0624, discriminator=3.91]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 70/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=519, perceptual=0.264, generator=1.29, feature=0.0628, discriminator=3.9]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████| 2/2 [00:00<00:00, 3.35it/s, kld=375, perceptual=0.261, generator=1.29, feature=0.0589, discriminator=3.89]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 71/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=329, perceptual=0.242, generator=1.18, feature=0.0615, discriminator=4.02]\n", + "100%|███████| 2/2 [00:00<00:00, 3.73it/s, kld=205, perceptual=0.267, generator=1.16, feature=0.067, discriminator=3.97]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 72/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=629, perceptual=0.233, generator=1.37, feature=0.0524, discriminator=3.98]\n", + "100%|██████| 2/2 
[00:00<00:00, 4.20it/s, kld=417, perceptual=0.233, generator=1.38, feature=0.0535, discriminator=3.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 73/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=198, perceptual=0.245, generator=1.45, feature=0.0611, discriminator=3.97]\n", + "100%|███████| 2/2 [00:00<00:00, 4.12it/s, kld=129, perceptual=0.263, generator=1.43, feature=0.071, discriminator=3.99]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 74/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=466, perceptual=0.279, generator=1.78, feature=0.0656, discriminator=3.92]\n", + "100%|██████| 2/2 [00:00<00:00, 3.96it/s, kld=308, perceptual=0.278, generator=1.79, feature=0.0681, discriminator=3.86]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 75/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=736, perceptual=0.245, generator=1.53, feature=0.0834, discriminator=4.15]\n", + "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=668, perceptual=0.235, generator=1.41, feature=0.0861, discriminator=4.06]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 76/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=815, perceptual=0.277, generator=1.98, feature=0.129, discriminator=3.59]\n", + "100%|███████| 2/2 [00:00<00:00, 4.03it/s, kld=525, perceptual=0.268, generator=2.05, feature=0.133, discriminator=3.32]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 77/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 
2.24it/s, kld=1e+3, perceptual=0.276, generator=2.28, feature=0.115, discriminator=3.72]\n", + "100%|███████| 2/2 [00:00<00:00, 3.66it/s, kld=660, perceptual=0.282, generator=2.32, feature=0.133, discriminator=3.53]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 78/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██| 15/15 [00:06<00:00, 2.21it/s, kld=1.36e+3, perceptual=0.279, generator=2.47, feature=0.14, discriminator=2.85]\n", + "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=957, perceptual=0.297, generator=2.59, feature=0.138, discriminator=2.82]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 79/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.22it/s, kld=2.86e+3, perceptual=0.311, generator=1.84, feature=0.142, discriminator=3.38]\n", + "100%|████| 2/2 [00:00<00:00, 3.54it/s, kld=2.04e+3, perceptual=0.328, generator=1.89, feature=0.148, discriminator=3.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 80/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.17it/s, kld=699, perceptual=0.338, generator=2.35, feature=0.14, discriminator=2.54]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████| 2/2 [00:00<00:00, 3.51it/s, kld=409, perceptual=0.331, generator=2.58, feature=0.147, discriminator=2.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 81/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=491, perceptual=0.355, generator=2.74, feature=0.146, discriminator=2.93]\n", + "100%|███████| 2/2 [00:00<00:00, 2.61it/s, kld=359, 
perceptual=0.314, generator=2.85, feature=0.143, discriminator=2.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 82/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=517, perceptual=0.334, generator=2.67, feature=0.14, discriminator=3.12]\n", + "100%|████████| 2/2 [00:00<00:00, 3.93it/s, kld=384, perceptual=0.326, generator=2.72, feature=0.14, discriminator=3.09]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 83/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=1.27e+3, perceptual=0.302, generator=1.8, feature=0.0732, discriminator=4]\n", + "100%|███████| 2/2 [00:00<00:00, 3.94it/s, kld=815, perceptual=0.31, generator=1.76, feature=0.0643, discriminator=3.98]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 84/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.20it/s, kld=1.01e+3, perceptual=0.281, generator=1.79, feature=0.0657, discriminator=3.96\n", + "100%|███████| 2/2 [00:00<00:00, 4.20it/s, kld=590, perceptual=0.263, generator=1.86, feature=0.071, discriminator=3.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 85/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=465, perceptual=0.253, generator=1.94, feature=0.0742, discriminator=3.63]\n", + "100%|███████| 2/2 [00:00<00:00, 4.08it/s, kld=329, perceptual=0.251, generator=1.95, feature=0.0817, discriminator=3.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 86/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=862, perceptual=0.243, 
generator=1.09, feature=0.0589, discriminator=3.92]\n", + "100%|██████| 2/2 [00:00<00:00, 3.89it/s, kld=591, perceptual=0.238, generator=1.14, feature=0.0605, discriminator=3.81]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 87/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.18it/s, kld=1.08e+3, perceptual=0.24, generator=1.74, feature=0.0678, discriminator=3.54]\n", + "100%|███████| 2/2 [00:00<00:00, 4.11it/s, kld=677, perceptual=0.231, generator=1.8, feature=0.0649, discriminator=3.91]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 88/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=958, perceptual=0.316, generator=1.8, feature=0.0715, discriminator=3.97]\n", + "100%|██████| 2/2 [00:00<00:00, 4.26it/s, kld=618, perceptual=0.301, generator=1.81, feature=0.0668, discriminator=3.96]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 89/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|███| 15/15 [00:06<00:00, 2.16it/s, kld=1.17e+3, perceptual=0.311, generator=1.31, feature=0.0658, discriminator=4]\n", + "100%|██████| 2/2 [00:00<00:00, 3.89it/s, kld=989, perceptual=0.305, generator=1.23, feature=0.0674, discriminator=3.95]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 90/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██| 15/15 [00:06<00:00, 2.23it/s, kld=2.17e+3, perceptual=0.294, generator=1.94, feature=0.101, discriminator=3.3]\n", + " 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████| 2/2 [00:00<00:00, 2.79it/s, kld=1.5e+3, perceptual=0.301, generator=1.96, feature=0.12, 
discriminator=3.16]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 91/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=1.88e+3, perceptual=0.29, generator=2.02, feature=0.0721, discriminator=3.34]\n", + "100%|██| 2/2 [00:00<00:00, 4.14it/s, kld=1.26e+3, perceptual=0.249, generator=1.96, feature=0.0852, discriminator=3.27]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 92/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=1.49e+3, perceptual=0.253, generator=1.96, feature=0.0624, discriminator=3.82\n", + "100%|██████| 2/2 [00:00<00:00, 3.80it/s, kld=862, perceptual=0.256, generator=1.96, feature=0.0589, discriminator=3.98]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 93/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=592, perceptual=0.252, generator=1.11, feature=0.108, discriminator=4.4]\n", + "100%|█████████| 2/2 [00:00<00:00, 3.92it/s, kld=365, perceptual=0.278, generator=1.19, feature=0.13, discriminator=4.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 94/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.21it/s, kld=1.65e+3, perceptual=0.267, generator=1.77, feature=0.121, discriminator=3.45]\n", + "100%|███| 2/2 [00:00<00:00, 4.14it/s, kld=1.05e+3, perceptual=0.244, generator=1.85, feature=0.123, discriminator=3.18]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 95/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=1.53e+3, perceptual=0.279, generator=1.74, feature=0.114, 
discriminator=3.33]\n", + "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=826, perceptual=0.259, generator=1.69, feature=0.0888, discriminator=3.51]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 96/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=940, perceptual=0.264, generator=1.88, feature=0.0612, discriminator=3.97]\n", + "100%|██████| 2/2 [00:00<00:00, 4.26it/s, kld=580, perceptual=0.262, generator=1.55, feature=0.0873, discriminator=4.16]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 97/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=588, perceptual=0.275, generator=1.81, feature=0.0542, discriminator=3.85]\n", + "100%|██████| 2/2 [00:00<00:00, 3.94it/s, kld=549, perceptual=0.262, generator=1.94, feature=0.0645, discriminator=3.75]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 98/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=658, perceptual=0.284, generator=1.73, feature=0.0515, discriminator=3.94]\n", + "100%|██████| 2/2 [00:00<00:00, 4.15it/s, kld=362, perceptual=0.271, generator=1.79, feature=0.0583, discriminator=3.94]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 99/100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=808, perceptual=0.281, generator=1.83, feature=0.0545, discriminator=3.86]\n", + "100%|██████| 2/2 [00:00<00:00, 2.92it/s, kld=598, perceptual=0.266, generator=1.79, feature=0.0619, discriminator=3.89]\n" ] } ], @@ -929,9 +2392,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.6" } }, 
"nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 067998b90260cf86962459caa8217149456c43a1 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 5 Jul 2023 09:40:38 +0100 Subject: [PATCH 5/7] Implemented changes as per PR correction: adding docstrings, removing blank lines, modifying the notebook to take losses more into account + conclusion. --- generative/losses/kld_loss.py | 11 + generative/networks/blocks/spade_norm.py | 11 + generative/networks/nets/spade_network.py | 70 ++- .../2d_spade_gan/2d_spade_vae.ipynb | 586 ++++++++++-------- .../generative/2d_spade_gan/2d_spade_vae.py | 26 +- 5 files changed, 421 insertions(+), 283 deletions(-) diff --git a/generative/losses/kld_loss.py b/generative/losses/kld_loss.py index 4b7e6f31..d3178b84 100644 --- a/generative/losses/kld_loss.py +++ b/generative/losses/kld_loss.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch.nn as nn import torch diff --git a/generative/networks/blocks/spade_norm.py b/generative/networks/blocks/spade_norm.py index 68991f0c..419f740e 100644 --- a/generative/networks/blocks/spade_norm.py +++ b/generative/networks/blocks/spade_norm.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import torch import torch.nn as nn diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py index ebf601e6..0be1d9c8 100644 --- a/generative/networks/nets/spade_network.py +++ b/generative/networks/nets/spade_network.py @@ -26,6 +26,17 @@ class UpsamplingModes(StrEnum): bilinear = "bilinear" class SPADE_ResNetBlock(nn.Module): + """ + Creates a Residual Block with SPADE normalisation. + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks + spade_intermediate_channels: number of intermediate channels in the middle conv. 
layers in SPADE normalisation blocks + norm: base normalisation type used on top of SPADE + kernel_size: convolutional kernel size + """ def __init__(self, spatial_dims: int, @@ -74,7 +85,6 @@ def __init__(self, norm=norm) def forward(self, x, seg): - x_s = self.shortcut(x, seg) dx = self.conv_0(self.activation(self.norm_0(x, seg))) dx = self.conv_1(self.activation(self.norm_1(dx, seg))) @@ -89,7 +99,19 @@ def shortcut(self, x, seg): return x_s class SPADE_Encoder(nn.Module): - + """ + Encoding branch of a VAE compatible with a SPADE-like generator + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + z_dim: latent space dimension of the VAE containing the image sytle information + num_channels: number of output after each downsampling block + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + of the autoencoder (HxWx[D]) + kernel_size: convolutional kernel size + norm: normalisation layer type + act: activation type + """ def __init__(self, spatial_dims: int, in_channels: int, @@ -139,7 +161,6 @@ def forward(self, x,): return mu, logvar def encode(self, x): - for block in self.blocks: x = block(x) x = x.view(x.size(0), -1) @@ -154,7 +175,25 @@ def reparameterize(self, mu, logvar): return eps.mul(std) + mu class SPADE_Decoder(nn.Module): - + """ + Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, + behaving like a GAN, or coupled to a SPADE encoder. 
+ Args: + label_nc: number of semantic labels + spatial_dims: number of spatial dimensions + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + num_channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_gan: whether the decoder is going to be coupled to an autoencoder or not (true: not, false: yes) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ def __init__(self, spatial_dims: int, out_channels: int, @@ -214,8 +253,6 @@ def __init__(self, def forward(self, seg, z: torch.Tensor = None): - - if self.is_gan: x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) x = self.fc(x) @@ -235,6 +272,26 @@ def forward(self, seg, z: torch.Tensor = None): class SPADE_Net(nn.Module): + """ + SPADE Network, implemented based on the code by Park, T et al. 
in "Semantic Image Synthesis with Spatially-Adaptive Normalization" + (https://github.com/NVlabs/SPADE) + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + num_channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + def __init__( self, spatial_dims: int, @@ -297,7 +354,6 @@ def __init__( ) def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): - z = None if self.is_vae: z_mu, z_logvar = self.encoder(x) diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb index bf7271c3..e5c64a3d 100644 --- a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -37,23 +37,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "e059c423", "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "A matching Triton is not available, some optimizations will not be enabled.\n", - "Error caught was: No module named 'triton'\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import tempfile\n", @@ -74,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "e76296e7", "metadata": { "scrolled": false @@ -84,7 +73,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Temporary directory used: /tmp/tmpo8gppqh6 \n" + "Temporary directory used: /tmp/tmpz0rbc_3s \n" ] } ], @@ -97,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 26, "id": "2483148a", "metadata": {}, "outputs": [], @@ -145,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "dc560f7e", "metadata": { "scrolled": true @@ -157,17 +146,17 @@ "text": [ "Downloading...\n", "From: https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\n", - "To: /tmp/tmpo8gppqh6/data.zip\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384M/384M [00:05<00:00, 69.6MB/s]\n" + "To: /tmp/tmpz0rbc_3s/data.zip\n", + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384M/384M [00:05<00:00, 67.0MB/s]\n" ] }, { "data": { "text/plain": [ - "'/tmp/tmpo8gppqh6/data.zip'" + "'/tmp/tmpz0rbc_3s/data.zip'" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -179,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "cd7dd6ec", "metadata": {}, "outputs": [], @@ -193,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 7, + 
"execution_count": 6, "id": "d48987b9", "metadata": {}, "outputs": [], @@ -219,21 +208,12 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "8ab79cdc", "metadata": { "lines_to_end_of_cell_marker": 2 }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.\n", - "monai.transforms.croppad.dictionary RandSpatialCropd.__init__:random_size: Current default value of argument `random_size=True` has been deprecated since version 1.1. It will be changed to `random_size=False` in version 1.3.\n" - ] - } - ], + "outputs": [], "source": [ "preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain\n", "crop_shape = input_shape + [1]\n", @@ -277,7 +257,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "98d14e75", "metadata": {}, "outputs": [ @@ -290,7 +270,7 @@ }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -320,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "fa17d864", "metadata": {}, "outputs": [], @@ -330,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "8f126b17", "metadata": {}, "outputs": [], @@ -355,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "255c90c7", "metadata": {}, "outputs": [], @@ -387,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "c18dbad8", "metadata": {}, "outputs": [], @@ -407,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "89989c34", "metadata": {}, "outputs": [], @@ -424,7 +404,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "5b8b676f", "metadata": {}, "outputs": [], @@ -445,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "36ea4308", "metadata": { "scrolled": false @@ -456,9 +436,7 @@ "output_type": "stream", "text": [ "The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", - "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n", - "Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /home/walter/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 528M/528M [00:07<00:00, 77.9MB/s]\n" + "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. 
The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n" ] } ], @@ -472,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "3b57abd3", "metadata": {}, "outputs": [], @@ -491,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 27, "id": "918eac0a", "metadata": { "scrolled": false @@ -508,13 +486,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|█████| 15/15 [00:16<00:00, 1.12s/it, kld=293, perceptual=0.339, generator=1.94, feature=0.125, discriminator=3.96]\n", - " 0%| | 0/2 [00:00" ] @@ -526,7 +504,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.21it/s, kld=197, perceptual=0.339, generator=1.93, feature=0.121, discriminator=3.94]" + "100%|██████| 2/2 [00:00<00:00, 3.22it/s, kld=63.6, perceptual=0.254, generator=3.38, feature=0.125, discriminator=1.77]" ] }, { @@ -541,8 +519,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=162, perceptual=0.361, generator=1.99, feature=0.127, discriminator=3.79]\n", - "100%|███████| 2/2 [00:00<00:00, 4.02it/s, kld=116, perceptual=0.324, generator=2.09, feature=0.131, discriminator=3.63]" + "100%|█████| 15/15 [00:06<00:00, 2.26it/s, kld=754, perceptual=0.354, generator=3.46, feature=0.168, discriminator=1.79]\n", + "100%|████████| 2/2 [00:00<00:00, 4.36it/s, kld=359, perceptual=0.36, generator=3.44, feature=0.173, discriminator=1.85]" ] }, { @@ -557,8 +535,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=155, perceptual=0.238, generator=2.41, feature=0.147, discriminator=3.03]\n", - "100%|████████| 2/2 [00:00<00:00, 3.92it/s, kld=179, perceptual=0.34, generator=2.41, feature=0.183, discriminator=2.95]" + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=298, perceptual=0.324, generator=3.57, 
feature=0.167, discriminator=1.58]\n", + "100%|████████| 2/2 [00:00<00:00, 4.07it/s, kld=346, perceptual=0.34, generator=3.57, feature=0.167, discriminator=1.63]" ] }, { @@ -573,8 +551,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=199, perceptual=0.286, generator=2.31, feature=0.164, discriminator=3.36]\n", - "100%|█████████| 2/2 [00:00<00:00, 4.28it/s, kld=144, perceptual=0.3, generator=2.53, feature=0.179, discriminator=2.68]" + "100%|███████| 15/15 [00:06<00:00, 2.27it/s, kld=313, perceptual=0.32, generator=3.57, feature=0.166, discriminator=1.6]\n", + "100%|███████| 2/2 [00:00<00:00, 4.36it/s, kld=262, perceptual=0.325, generator=3.49, feature=0.164, discriminator=1.63]" ] }, { @@ -589,8 +567,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=483, perceptual=0.47, generator=2.29, feature=0.238, discriminator=2.82]\n", - "100%|███████| 2/2 [00:00<00:00, 4.43it/s, kld=296, perceptual=0.449, generator=2.33, feature=0.237, discriminator=2.78]" + "100%|█████| 15/15 [00:06<00:00, 2.28it/s, kld=213, perceptual=0.305, generator=3.52, feature=0.168, discriminator=1.67]\n", + "100%|███████| 2/2 [00:00<00:00, 4.45it/s, kld=101, perceptual=0.307, generator=3.58, feature=0.164, discriminator=1.65]" ] }, { @@ -605,8 +583,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=258, perceptual=0.456, generator=2.84, feature=0.273, discriminator=2.15]\n", - "100%|███████| 2/2 [00:00<00:00, 4.31it/s, kld=150, perceptual=0.466, generator=2.83, feature=0.275, discriminator=2.15]" + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=119, perceptual=0.302, generator=3.6, feature=0.165, discriminator=1.51]\n", + "100%|██████| 2/2 [00:00<00:00, 3.79it/s, kld=68.9, perceptual=0.288, generator=3.58, feature=0.165, discriminator=1.61]" ] }, { @@ -621,8 +599,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.17it/s, kld=80.9, 
perceptual=0.406, generator=2.88, feature=0.245, discriminator=2.05]\n", - "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=51.4, perceptual=0.422, generator=2.89, feature=0.249, discriminator=1.97]" + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=114, perceptual=0.286, generator=3.59, feature=0.163, discriminator=1.5]\n", + "100%|████████| 2/2 [00:00<00:00, 4.25it/s, kld=80, perceptual=0.278, generator=3.59, feature=0.161, discriminator=1.51]" ] }, { @@ -637,8 +615,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=101, perceptual=0.375, generator=2.75, feature=0.194, discriminator=2.1]\n", - "100%|██████| 2/2 [00:00<00:00, 4.16it/s, kld=73.3, perceptual=0.341, generator=2.81, feature=0.162, discriminator=2.65]" + "100%|████████| 15/15 [00:06<00:00, 2.28it/s, kld=116, perceptual=0.263, generator=3.6, feature=0.16, discriminator=1.5]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.44it/s, kld=73.9, perceptual=0.269, generator=3.6, feature=0.16, discriminator=1.5]" ] }, { @@ -653,8 +631,8 @@ "output_type": "stream", "text": [ "\n", - "100%|███████| 15/15 [00:07<00:00, 2.14it/s, kld=130, perceptual=0.4, generator=2.91, feature=0.203, discriminator=1.89]\n", - "100%|███████| 2/2 [00:00<00:00, 3.42it/s, kld=101, perceptual=0.385, generator=2.94, feature=0.185, discriminator=2.08]" + "100%|███████| 15/15 [00:06<00:00, 2.22it/s, kld=100, perceptual=0.25, generator=3.6, feature=0.155, discriminator=1.53]\n", + "100%|█████████| 2/2 [00:00<00:00, 4.31it/s, kld=54.1, perceptual=0.26, generator=3.6, feature=0.154, discriminator=1.5]" ] }, { @@ -669,8 +647,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=156, perceptual=0.352, generator=2.68, feature=0.184, discriminator=2.03]\n", - "100%|██████| 2/2 [00:00<00:00, 4.28it/s, kld=97.8, perceptual=0.343, generator=2.63, feature=0.183, discriminator=2.08]" + "100%|████| 15/15 [00:06<00:00, 2.19it/s, kld=97.2, perceptual=0.243, generator=3.54, 
feature=0.149, discriminator=1.57]\n", + "100%|██████| 2/2 [00:00<00:00, 4.01it/s, kld=69.7, perceptual=0.252, generator=3.62, feature=0.141, discriminator=1.46]" ] }, { @@ -685,13 +663,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=159, perceptual=0.347, generator=2.92, feature=0.185, discriminator=1.76]\n", - " 0%| | 0/2 [00:00" ] @@ -703,7 +681,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.39it/s, kld=149, perceptual=0.372, generator=2.91, feature=0.186, discriminator=1.79]" + "100%|████████| 2/2 [00:00<00:00, 3.42it/s, kld=109, perceptual=0.258, generator=3.54, feature=0.128, discriminator=1.6]" ] }, { @@ -718,8 +696,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=155, perceptual=0.336, generator=2.94, feature=0.162, discriminator=2.13]\n", - "100%|██████| 2/2 [00:00<00:00, 3.88it/s, kld=98.6, perceptual=0.314, generator=2.96, feature=0.145, discriminator=2.31]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=413, perceptual=0.325, generator=3.21, feature=0.144, discriminator=2.18]\n", + "100%|███████| 2/2 [00:00<00:00, 3.80it/s, kld=252, perceptual=0.325, generator=3.39, feature=0.141, discriminator=1.88]" ] }, { @@ -734,8 +712,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.23it/s, kld=1.21e+3, perceptual=0.379, generator=2.97, feature=0.186, discriminator=1.68]\n", - "100%|██████████| 2/2 [00:00<00:00, 4.23it/s, kld=817, perceptual=0.401, generator=3, feature=0.189, discriminator=1.66]" + "100%|█████| 15/15 [00:06<00:00, 2.26it/s, kld=238, perceptual=0.311, generator=3.55, feature=0.169, discriminator=1.67]\n", + "100%|████████| 2/2 [00:00<00:00, 4.29it/s, kld=214, perceptual=0.317, generator=3.57, feature=0.17, discriminator=1.61]" ] }, { @@ -750,8 +728,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=118, perceptual=0.377, 
generator=3.04, feature=0.168, discriminator=2.01]\n", - "100%|██████| 2/2 [00:00<00:00, 3.45it/s, kld=72.6, perceptual=0.365, generator=3.06, feature=0.179, discriminator=1.59]" + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=110, perceptual=0.304, generator=3.57, feature=0.177, discriminator=1.55]\n", + "100%|███████| 2/2 [00:00<00:00, 4.01it/s, kld=53.1, perceptual=0.314, generator=3.58, feature=0.176, discriminator=1.5]" ] }, { @@ -766,8 +744,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=2.05e+3, perceptual=0.386, generator=2.48, feature=0.177, discriminator=2.18]\n", - "100%|███| 2/2 [00:00<00:00, 4.27it/s, kld=1.31e+3, perceptual=0.326, generator=2.54, feature=0.166, discriminator=2.24]" + "100%|████████| 15/15 [00:06<00:00, 2.16it/s, kld=129, perceptual=0.27, generator=3.6, feature=0.17, discriminator=1.52]\n", + "100%|████████| 2/2 [00:00<00:00, 4.45it/s, kld=83.5, perceptual=0.295, generator=3.6, feature=0.17, discriminator=1.47]" ] }, { @@ -782,8 +760,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.18it/s, kld=1.19e+3, perceptual=0.337, generator=2.82, feature=0.173, discriminator=1.79]\n", - "100%|████████| 2/2 [00:00<00:00, 4.00it/s, kld=730, perceptual=0.329, generator=2.82, feature=0.17, discriminator=1.78]" + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=74.5, perceptual=0.254, generator=3.59, feature=0.166, discriminator=1.52]\n", + "100%|███████| 2/2 [00:00<00:00, 4.49it/s, kld=38.9, perceptual=0.285, generator=3.6, feature=0.166, discriminator=1.47]" ] }, { @@ -798,8 +776,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=383, perceptual=0.348, generator=2.8, feature=0.144, discriminator=1.78]\n", - "100%|████████| 2/2 [00:00<00:00, 3.79it/s, kld=204, perceptual=0.328, generator=2.87, feature=0.146, discriminator=1.8]" + "100%|██████| 15/15 [00:06<00:00, 2.29it/s, kld=103, perceptual=0.25, generator=3.63, feature=0.162, 
discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.36it/s, kld=80.4, perceptual=0.246, generator=3.6, feature=0.159, discriminator=1.47]" ] }, { @@ -814,8 +792,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=327, perceptual=0.349, generator=2.62, feature=0.0884, discriminator=3.06]\n", - "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=198, perceptual=0.337, generator=2.67, feature=0.0891, discriminator=3.01]" + "100%|████| 15/15 [00:06<00:00, 2.28it/s, kld=95.8, perceptual=0.247, generator=3.62, feature=0.162, discriminator=1.44]\n", + "100%|████████| 2/2 [00:00<00:00, 4.44it/s, kld=60, perceptual=0.233, generator=3.51, feature=0.166, discriminator=1.59]" ] }, { @@ -830,8 +808,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.17it/s, kld=780, perceptual=0.363, generator=2.98, feature=0.0928, discriminator=2.89]\n", - "100%|██████| 2/2 [00:00<00:00, 3.25it/s, kld=522, perceptual=0.379, generator=2.94, feature=0.0919, discriminator=2.82]" + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=476, perceptual=0.248, generator=3.36, feature=0.152, discriminator=1.84]\n", + "100%|███████| 2/2 [00:00<00:00, 2.86it/s, kld=181, perceptual=0.251, generator=3.59, feature=0.172, discriminator=1.53]" ] }, { @@ -846,8 +824,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=711, perceptual=0.346, generator=1.89, feature=0.0815, discriminator=3.88]\n", - "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=413, perceptual=0.327, generator=2.01, feature=0.0848, discriminator=3.68]" + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=478, perceptual=0.285, generator=3.52, feature=0.15, discriminator=1.87]\n", + "100%|██████| 2/2 [00:00<00:00, 4.42it/s, kld=483, perceptual=0.304, generator=2.73, feature=0.0971, discriminator=2.46]" ] }, { @@ -862,13 +840,13 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.17it/s, kld=329, perceptual=0.339, 
generator=2.13, feature=0.0703, discriminator=3.87]\n", - " 0%| | 0/2 [00:00" ] @@ -880,7 +858,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.01it/s, kld=203, perceptual=0.336, generator=1.99, feature=0.078, discriminator=3.71]" + "100%|███████| 2/2 [00:00<00:00, 2.95it/s, kld=530, perceptual=0.346, generator=3.16, feature=0.143, discriminator=1.94]" ] }, { @@ -895,8 +873,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=174, perceptual=0.353, generator=2.58, feature=0.0856, discriminator=3.15]\n", - "100%|██████| 2/2 [00:00<00:00, 4.24it/s, kld=114, perceptual=0.274, generator=2.73, feature=0.0881, discriminator=3.83]" + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=143, perceptual=0.382, generator=3.6, feature=0.218, discriminator=1.46]\n", + "100%|██████| 2/2 [00:00<00:00, 4.35it/s, kld=80.1, perceptual=0.399, generator=3.61, feature=0.225, discriminator=1.44]" ] }, { @@ -911,8 +889,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=195, perceptual=0.36, generator=2.49, feature=0.106, discriminator=3.85]\n", - "100%|███████| 2/2 [00:00<00:00, 3.47it/s, kld=141, perceptual=0.343, generator=2.38, feature=0.0823, discriminator=3.8]" + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=179, perceptual=0.344, generator=3.59, feature=0.209, discriminator=1.51]\n", + "100%|██████| 2/2 [00:00<00:00, 4.43it/s, kld=91.7, perceptual=0.366, generator=3.62, feature=0.217, discriminator=1.38]" ] }, { @@ -927,8 +905,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=407, perceptual=0.355, generator=2.13, feature=0.0831, discriminator=3.37]\n", - "100%|███████| 2/2 [00:00<00:00, 4.19it/s, kld=190, perceptual=0.316, generator=2.47, feature=0.108, discriminator=3.25]" + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=336, perceptual=0.31, generator=3.59, feature=0.182, discriminator=1.49]\n", + 
"100%|████████| 2/2 [00:00<00:00, 4.46it/s, kld=216, perceptual=0.311, generator=3.6, feature=0.185, discriminator=1.48]" ] }, { @@ -943,8 +921,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=380, perceptual=0.375, generator=1.69, feature=0.108, discriminator=3.62]\n", - "100%|███████| 2/2 [00:00<00:00, 4.01it/s, kld=234, perceptual=0.348, generator=1.86, feature=0.103, discriminator=3.33]" + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=423, perceptual=0.271, generator=3.58, feature=0.157, discriminator=1.53]\n", + "100%|███████| 2/2 [00:00<00:00, 4.07it/s, kld=136, perceptual=0.269, generator=3.59, feature=0.168, discriminator=1.52]" ] }, { @@ -959,8 +937,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=207, perceptual=0.345, generator=1.8, feature=0.0924, discriminator=3.78]\n", - "100%|█████████| 2/2 [00:00<00:00, 4.41it/s, kld=129, perceptual=0.348, generator=1.97, feature=0.1, discriminator=3.72]" + "100%|██████| 15/15 [00:06<00:00, 2.26it/s, kld=349, perceptual=0.276, generator=3.6, feature=0.155, discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.41it/s, kld=220, perceptual=0.272, generator=3.61, feature=0.156, discriminator=1.45]" ] }, { @@ -975,8 +953,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=294, perceptual=0.334, generator=2.76, feature=0.119, discriminator=2.83]\n", - "100%|████████| 2/2 [00:00<00:00, 4.38it/s, kld=211, perceptual=0.355, generator=2.52, feature=0.123, discriminator=2.9]" + "100%|█████| 15/15 [00:06<00:00, 2.31it/s, kld=109, perceptual=0.265, generator=3.62, feature=0.156, discriminator=1.46]\n", + "100%|███████| 2/2 [00:00<00:00, 4.00it/s, kld=117, perceptual=0.246, generator=3.59, feature=0.157, discriminator=1.52]" ] }, { @@ -991,8 +969,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.20it/s, kld=384, perceptual=0.361, generator=2.86, 
feature=0.126, discriminator=2.7]\n", - "100%|███████| 2/2 [00:00<00:00, 3.77it/s, kld=164, perceptual=0.341, generator=3.05, feature=0.147, discriminator=2.77]" + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=209, perceptual=0.235, generator=3.61, feature=0.154, discriminator=1.47]\n", + "100%|███████| 2/2 [00:00<00:00, 4.47it/s, kld=120, perceptual=0.256, generator=3.61, feature=0.156, discriminator=1.43]" ] }, { @@ -1007,8 +985,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=257, perceptual=0.354, generator=1.4, feature=0.0816, discriminator=3.77]\n", - "100%|███████| 2/2 [00:00<00:00, 4.09it/s, kld=173, perceptual=0.356, generator=1.36, feature=0.0792, discriminator=3.7]" + "100%|██████| 15/15 [00:06<00:00, 2.33it/s, kld=146, perceptual=0.231, generator=3.61, feature=0.15, discriminator=1.44]\n", + "100%|██████| 2/2 [00:00<00:00, 4.35it/s, kld=76.3, perceptual=0.248, generator=3.61, feature=0.153, discriminator=1.41]" ] }, { @@ -1023,8 +1001,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=166, perceptual=0.366, generator=1.22, feature=0.0643, discriminator=3.9]\n", - "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=92.5, perceptual=0.314, generator=1.46, feature=0.065, discriminator=3.55]" + "100%|█████| 15/15 [00:06<00:00, 2.28it/s, kld=180, perceptual=0.225, generator=3.59, feature=0.147, discriminator=1.45]\n", + "100%|███████| 2/2 [00:00<00:00, 4.18it/s, kld=98.1, perceptual=0.228, generator=3.6, feature=0.149, discriminator=1.43]" ] }, { @@ -1039,13 +1017,13 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=171, perceptual=0.331, generator=1.94, feature=0.0693, discriminator=3.74]\n", - " 0%| | 0/2 [00:00" ] @@ -1057,7 +1035,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████| 2/2 [00:00<00:00, 3.36it/s, kld=143, perceptual=0.311, generator=1.89, feature=0.0648, discriminator=3.33]" + "100%|█████████| 
2/2 [00:00<00:00, 3.31it/s, kld=90, perceptual=0.226, generator=3.6, feature=0.142, discriminator=1.43]" ] }, { @@ -1072,8 +1050,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=264, perceptual=0.363, generator=1.98, feature=0.0766, discriminator=3.99]\n", - "100%|██████| 2/2 [00:00<00:00, 3.75it/s, kld=178, perceptual=0.348, generator=1.96, feature=0.0681, discriminator=3.98]" + "100%|██████| 15/15 [00:06<00:00, 2.21it/s, kld=114, perceptual=0.227, generator=3.63, feature=0.145, discriminator=1.4]\n", + "100%|██████| 2/2 [00:00<00:00, 4.38it/s, kld=62.7, perceptual=0.212, generator=3.64, feature=0.144, discriminator=1.32]" ] }, { @@ -1088,8 +1066,8 @@ "output_type": "stream", "text": [ "\n", - "100%|███| 15/15 [00:06<00:00, 2.20it/s, kld=85.9, perceptual=0.308, generator=2.05, feature=0.0725, discriminator=3.94]\n", - "100%|██████| 2/2 [00:00<00:00, 3.43it/s, kld=55.1, perceptual=0.327, generator=2.17, feature=0.113, discriminator=3.84]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=102, perceptual=0.229, generator=3.62, feature=0.149, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.42it/s, kld=89.2, perceptual=0.233, generator=3.63, feature=0.143, discriminator=1.36]" ] }, { @@ -1104,8 +1082,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=110, perceptual=0.293, generator=2.17, feature=0.0944, discriminator=3.92]\n", - "100%|██████| 2/2 [00:00<00:00, 3.94it/s, kld=110, perceptual=0.329, generator=2.01, feature=0.0707, discriminator=4.03]" + "100%|███████| 15/15 [00:06<00:00, 2.23it/s, kld=155, perceptual=0.2, generator=3.61, feature=0.144, discriminator=1.38]\n", + "100%|██████| 2/2 [00:00<00:00, 4.19it/s, kld=70.7, perceptual=0.226, generator=3.61, feature=0.144, discriminator=1.41]" ] }, { @@ -1120,8 +1098,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=203, perceptual=0.329, generator=2.42, feature=0.0917, 
discriminator=3.65]\n", - "100%|███████| 2/2 [00:00<00:00, 2.71it/s, kld=106, perceptual=0.29, generator=2.39, feature=0.0954, discriminator=3.58]" + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=118, perceptual=0.217, generator=3.6, feature=0.144, discriminator=1.38]\n", + "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=77.5, perceptual=0.204, generator=3.62, feature=0.145, discriminator=1.37]" ] }, { @@ -1136,8 +1114,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=282, perceptual=0.28, generator=2.44, feature=0.0842, discriminator=3.62]\n", - "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=206, perceptual=0.284, generator=2.42, feature=0.0789, discriminator=3.64]" + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=128, perceptual=0.199, generator=3.49, feature=0.138, discriminator=1.54]\n", + "100%|██████| 2/2 [00:00<00:00, 4.23it/s, kld=72.3, perceptual=0.219, generator=3.58, feature=0.146, discriminator=1.42]" ] }, { @@ -1152,8 +1130,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=313, perceptual=0.333, generator=2.3, feature=0.0794, discriminator=4.01]\n", - "100%|███████| 2/2 [00:00<00:00, 4.14it/s, kld=192, perceptual=0.285, generator=2.4, feature=0.0781, discriminator=4.01]" + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=137, perceptual=0.206, generator=3.62, feature=0.144, discriminator=1.37]\n", + "100%|██████| 2/2 [00:00<00:00, 4.47it/s, kld=98.9, perceptual=0.222, generator=3.63, feature=0.145, discriminator=1.42]" ] }, { @@ -1168,8 +1146,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=101, perceptual=0.297, generator=2.4, feature=0.0699, discriminator=3.96]\n", - "100%|█████| 2/2 [00:00<00:00, 3.69it/s, kld=56.2, perceptual=0.279, generator=2.44, feature=0.0671, discriminator=3.96]" + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=250, perceptual=0.227, generator=3.61, feature=0.148, discriminator=1.47]\n", + 
"100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=142, perceptual=0.217, generator=3.58, feature=0.148, discriminator=1.45]" ] }, { @@ -1184,8 +1162,8 @@ "output_type": "stream", "text": [ "\n", - "100%|███| 15/15 [00:06<00:00, 2.17it/s, kld=94.7, perceptual=0.296, generator=2.37, feature=0.0693, discriminator=3.91]\n", - "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=64, perceptual=0.239, generator=2.42, feature=0.0572, discriminator=3.95]" + "100%|█████| 15/15 [00:06<00:00, 2.32it/s, kld=101, perceptual=0.208, generator=3.62, feature=0.145, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.44it/s, kld=66.2, perceptual=0.211, generator=3.63, feature=0.145, discriminator=1.35]" ] }, { @@ -1200,8 +1178,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=154, perceptual=0.293, generator=2.28, feature=0.055, discriminator=3.9]\n", - "100%|███████| 2/2 [00:00<00:00, 3.99it/s, kld=93.5, perceptual=0.27, generator=2.25, feature=0.0545, discriminator=3.9]" + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=99.7, perceptual=0.209, generator=3.61, feature=0.146, discriminator=1.37]\n", + "100%|██████| 2/2 [00:00<00:00, 3.83it/s, kld=74.2, perceptual=0.202, generator=3.62, feature=0.143, discriminator=1.38]" ] }, { @@ -1216,13 +1194,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=268, perceptual=0.227, generator=2.26, feature=0.052, discriminator=3.98]\n", - " 0%| | 0/2 [00:00" ] @@ -1234,7 +1212,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.34it/s, kld=247, perceptual=0.269, generator=2.15, feature=0.056, discriminator=3.95]" + "100%|██████| 2/2 [00:00<00:00, 3.33it/s, kld=86.4, perceptual=0.205, generator=3.63, feature=0.142, discriminator=1.36]" ] }, { @@ -1249,8 +1227,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=195, perceptual=0.231, generator=2.14, feature=0.054, 
discriminator=4.01]\n", - "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=128, perceptual=0.238, generator=2.14, feature=0.0585, discriminator=4.02]" + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=87.2, perceptual=0.208, generator=3.64, feature=0.138, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=73.5, perceptual=0.197, generator=3.62, feature=0.141, discriminator=1.35]" ] }, { @@ -1265,8 +1243,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=312, perceptual=0.238, generator=2.29, feature=0.0563, discriminator=3.91]\n", - "100%|███████| 2/2 [00:00<00:00, 3.29it/s, kld=217, perceptual=0.24, generator=2.29, feature=0.0595, discriminator=3.89]" + "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=120, perceptual=0.219, generator=3.62, feature=0.148, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 4.48it/s, kld=55.8, perceptual=0.207, generator=3.63, feature=0.136, discriminator=1.33]" ] }, { @@ -1281,8 +1259,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=653, perceptual=0.225, generator=2.5, feature=0.0637, discriminator=3.79]\n", - "100%|██████| 2/2 [00:00<00:00, 4.31it/s, kld=390, perceptual=0.256, generator=2.43, feature=0.0642, discriminator=3.72]" + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=95.9, perceptual=0.198, generator=3.64, feature=0.135, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.34it/s, kld=57.2, perceptual=0.209, generator=3.66, feature=0.135, discriminator=1.25]" ] }, { @@ -1297,8 +1275,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.21it/s, kld=860, perceptual=0.261, generator=2.38, feature=0.0658, discriminator=3.85]\n", - "100%|██████| 2/2 [00:00<00:00, 3.70it/s, kld=513, perceptual=0.241, generator=2.41, feature=0.0746, discriminator=3.82]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=128, perceptual=0.197, generator=3.62, feature=0.132, discriminator=1.34]\n", + 
"100%|██████| 2/2 [00:00<00:00, 3.98it/s, kld=82.2, perceptual=0.182, generator=3.62, feature=0.124, discriminator=1.35]" ] }, { @@ -1313,8 +1291,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=422, perceptual=0.282, generator=2.38, feature=0.0676, discriminator=3.99]\n", - "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=325, perceptual=0.234, generator=2.5, feature=0.0701, discriminator=4.01]" + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=108, perceptual=0.184, generator=3.63, feature=0.126, discriminator=1.32]\n", + "100%|██████| 2/2 [00:00<00:00, 4.10it/s, kld=72.3, perceptual=0.195, generator=3.64, feature=0.132, discriminator=1.29]" ] }, { @@ -1329,8 +1307,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=2.07e+3, perceptual=0.257, generator=2.42, feature=0.0706, discriminator=3.89\n", - "100%|██| 2/2 [00:00<00:00, 3.85it/s, kld=1.21e+3, perceptual=0.247, generator=2.43, feature=0.0692, discriminator=3.88]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=115, perceptual=0.188, generator=3.62, feature=0.126, discriminator=1.32]\n", + "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=60.3, perceptual=0.213, generator=3.65, feature=0.129, discriminator=1.27]" ] }, { @@ -1345,8 +1323,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=338, perceptual=0.263, generator=2.39, feature=0.0861, discriminator=4.03]\n", - "100%|███████| 2/2 [00:00<00:00, 4.18it/s, kld=248, perceptual=0.246, generator=2.5, feature=0.0764, discriminator=3.86]" + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=105, perceptual=0.206, generator=3.64, feature=0.127, discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.34it/s, kld=59.4, perceptual=0.202, generator=3.66, feature=0.125, discriminator=1.26]" ] }, { @@ -1361,8 +1339,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=331, perceptual=0.246, generator=2.47, 
feature=0.0839, discriminator=3.71]\n", - "100%|███████| 2/2 [00:00<00:00, 4.31it/s, kld=205, perceptual=0.243, generator=2.39, feature=0.088, discriminator=3.79]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=106, perceptual=0.176, generator=3.63, feature=0.128, discriminator=1.32]\n", + "100%|███████| 2/2 [00:00<00:00, 4.31it/s, kld=55.1, perceptual=0.21, generator=3.64, feature=0.134, discriminator=1.29]" ] }, { @@ -1377,8 +1355,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██| 15/15 [00:06<00:00, 2.19it/s, kld=3.41e+3, perceptual=0.248, generator=2.58, feature=0.12, discriminator=3.76]\n", - "100%|████| 2/2 [00:00<00:00, 3.71it/s, kld=2.19e+3, perceptual=0.26, generator=2.55, feature=0.125, discriminator=3.59]" + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=88.5, perceptual=0.192, generator=3.62, feature=0.131, discriminator=1.33]\n", + "100%|████████| 2/2 [00:00<00:00, 4.29it/s, kld=63.5, perceptual=0.2, generator=3.61, feature=0.131, discriminator=1.35]" ] }, { @@ -1393,13 +1371,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.21it/s, kld=1.09e+3, perceptual=0.246, generator=2.76, feature=0.128, discriminator=3.32]\n", - " 0%| | 0/2 [00:00" ] @@ -1411,7 +1389,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████| 2/2 [00:00<00:00, 3.31it/s, kld=704, perceptual=0.241, generator=2.83, feature=0.145, discriminator=3.18]" + "100%|██████| 2/2 [00:00<00:00, 3.45it/s, kld=52.4, perceptual=0.203, generator=3.65, feature=0.133, discriminator=1.27]" ] }, { @@ -1426,8 +1404,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=1.35e+3, perceptual=0.247, generator=1.95, feature=0.0697, discriminator=3.94\n", - "100%|██████| 2/2 [00:00<00:00, 4.27it/s, kld=891, perceptual=0.269, generator=1.84, feature=0.0785, discriminator=4.04]" + "100%|████| 15/15 [00:06<00:00, 2.27it/s, kld=87.7, perceptual=0.192, generator=3.62, feature=0.131, discriminator=1.34]\n", + "100%|███████| 
2/2 [00:00<00:00, 4.27it/s, kld=50.3, perceptual=0.207, generator=3.64, feature=0.133, discriminator=1.3]" ] }, { @@ -1442,8 +1420,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=373, perceptual=0.286, generator=1.98, feature=0.0785, discriminator=3.87]\n", - "100%|███████| 2/2 [00:00<00:00, 3.92it/s, kld=274, perceptual=0.285, generator=2.02, feature=0.0876, discriminator=3.8]" + "100%|████| 15/15 [00:06<00:00, 2.30it/s, kld=91.9, perceptual=0.184, generator=3.62, feature=0.127, discriminator=1.33]\n", + "100%|██████| 2/2 [00:00<00:00, 4.32it/s, kld=65.7, perceptual=0.182, generator=3.62, feature=0.128, discriminator=1.35]" ] }, { @@ -1458,8 +1436,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.16it/s, kld=637, perceptual=0.264, generator=2.5, feature=0.0888, discriminator=3.63]\n", - "100%|██████| 2/2 [00:00<00:00, 3.50it/s, kld=417, perceptual=0.268, generator=2.46, feature=0.0777, discriminator=3.49]" + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=84.6, perceptual=0.193, generator=3.62, feature=0.136, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 4.45it/s, kld=67.2, perceptual=0.176, generator=3.63, feature=0.129, discriminator=1.32]" ] }, { @@ -1474,8 +1452,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=584, perceptual=0.274, generator=2.46, feature=0.095, discriminator=3.23]\n", - "100%|██████| 2/2 [00:00<00:00, 4.24it/s, kld=420, perceptual=0.254, generator=2.53, feature=0.0873, discriminator=3.37]" + "100%|████████| 15/15 [00:06<00:00, 2.31it/s, kld=91.9, perceptual=0.2, generator=3.64, feature=0.14, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.49it/s, kld=43.8, perceptual=0.204, generator=3.64, feature=0.133, discriminator=1.28]" ] }, { @@ -1490,8 +1468,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.19it/s, kld=842, perceptual=0.272, generator=1.82, feature=0.0686, 
discriminator=3.99]\n", - "100%|██████| 2/2 [00:00<00:00, 4.11it/s, kld=521, perceptual=0.256, generator=1.91, feature=0.0724, discriminator=3.97]" + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=118, perceptual=0.183, generator=3.64, feature=0.129, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.44it/s, kld=71.7, perceptual=0.195, generator=3.6, feature=0.141, discriminator=1.37]" ] }, { @@ -1506,8 +1484,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=301, perceptual=0.292, generator=2.01, feature=0.0822, discriminator=3.88]\n", - "100%|███████| 2/2 [00:00<00:00, 4.03it/s, kld=169, perceptual=0.291, generator=2.27, feature=0.103, discriminator=3.59]" + "100%|██████| 15/15 [00:06<00:00, 2.35it/s, kld=102, perceptual=0.202, generator=3.63, feature=0.137, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.18it/s, kld=64.4, perceptual=0.188, generator=3.62, feature=0.133, discriminator=1.33]" ] }, { @@ -1522,8 +1500,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.15it/s, kld=835, perceptual=0.302, generator=2.35, feature=0.104, discriminator=3.89]\n", - "100%|█████████| 2/2 [00:00<00:00, 4.33it/s, kld=517, perceptual=0.341, generator=2.3, feature=0.119, discriminator=3.6]" + "100%|████| 15/15 [00:06<00:00, 2.28it/s, kld=96.8, perceptual=0.181, generator=3.61, feature=0.125, discriminator=1.36]\n", + "100%|██████| 2/2 [00:00<00:00, 4.54it/s, kld=54.1, perceptual=0.201, generator=3.59, feature=0.136, discriminator=1.36]" ] }, { @@ -1538,8 +1516,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.18it/s, kld=425, perceptual=0.322, generator=2.7, feature=0.174, discriminator=3.03]\n", - "100%|███████| 2/2 [00:00<00:00, 4.24it/s, kld=245, perceptual=0.326, generator=2.72, feature=0.162, discriminator=3.07]" + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=137, perceptual=0.195, generator=3.62, feature=0.134, discriminator=1.35]\n", + 
"100%|██████| 2/2 [00:00<00:00, 4.09it/s, kld=64.9, perceptual=0.198, generator=3.63, feature=0.133, discriminator=1.32]" ] }, { @@ -1554,8 +1532,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=139, perceptual=0.302, generator=2.79, feature=0.149, discriminator=2.91]\n", - "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=69.9, perceptual=0.284, generator=2.83, feature=0.137, discriminator=2.89]" + "100%|█████| 15/15 [00:06<00:00, 2.33it/s, kld=77.3, perceptual=0.189, generator=3.64, feature=0.13, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 3.47it/s, kld=48.4, perceptual=0.187, generator=3.65, feature=0.132, discriminator=1.26]" ] }, { @@ -1570,13 +1548,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.20it/s, kld=210, perceptual=0.278, generator=2.84, feature=0.133, discriminator=2.84]\n", - " 0%| | 0/2 [00:00" ] @@ -1588,7 +1566,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████| 2/2 [00:00<00:00, 3.44it/s, kld=132, perceptual=0.257, generator=2.86, feature=0.12, discriminator=2.85]" + "100%|██████| 2/2 [00:00<00:00, 3.59it/s, kld=45.9, perceptual=0.196, generator=3.65, feature=0.133, discriminator=1.26]" ] }, { @@ -1603,8 +1581,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=835, perceptual=0.266, generator=1.8, feature=0.0844, discriminator=3.96]\n", - "100%|██████| 2/2 [00:00<00:00, 3.26it/s, kld=767, perceptual=0.293, generator=1.72, feature=0.0979, discriminator=3.83]" + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=91.5, perceptual=0.192, generator=3.64, feature=0.134, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.30it/s, kld=51.4, perceptual=0.175, generator=3.63, feature=0.127, discriminator=1.29]" ] }, { @@ -1619,8 +1597,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=908, perceptual=0.306, generator=1.62, feature=0.102, 
discriminator=3.5]\n", - "100%|███████| 2/2 [00:00<00:00, 4.25it/s, kld=568, perceptual=0.324, generator=1.66, feature=0.107, discriminator=3.45]" + "100%|████| 15/15 [00:06<00:00, 2.24it/s, kld=89.7, perceptual=0.199, generator=3.65, feature=0.132, discriminator=1.25]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=51.4, perceptual=0.179, generator=3.59, feature=0.133, discriminator=1.36]" ] }, { @@ -1635,8 +1613,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=865, perceptual=0.304, generator=1.54, feature=0.0938, discriminator=3.68]\n", - "100%|██████| 2/2 [00:00<00:00, 3.77it/s, kld=668, perceptual=0.295, generator=1.62, feature=0.0881, discriminator=3.65]" + "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=85.4, perceptual=0.18, generator=3.61, feature=0.13, discriminator=1.34]\n", + "100%|██████| 2/2 [00:00<00:00, 3.02it/s, kld=69.2, perceptual=0.195, generator=3.62, feature=0.135, discriminator=1.33]" ] }, { @@ -1651,8 +1629,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=915, perceptual=0.272, generator=1.26, feature=0.074, discriminator=3.78]\n", - "100%|██████| 2/2 [00:00<00:00, 4.09it/s, kld=570, perceptual=0.299, generator=1.21, feature=0.0785, discriminator=3.76]" + "100%|██████| 15/15 [00:06<00:00, 2.30it/s, kld=113, perceptual=0.185, generator=3.64, feature=0.126, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.46it/s, kld=67.6, perceptual=0.206, generator=3.65, feature=0.128, discriminator=1.27]" ] }, { @@ -1667,8 +1645,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=343, perceptual=0.268, generator=1.28, feature=0.0697, discriminator=3.87]\n", - "100%|███████| 2/2 [00:00<00:00, 4.37it/s, kld=209, perceptual=0.278, generator=1.2, feature=0.0735, discriminator=3.81]" + "100%|██████| 15/15 [00:06<00:00, 2.27it/s, kld=79, perceptual=0.163, generator=3.61, feature=0.124, discriminator=1.33]\n", + 
"100%|██████| 2/2 [00:00<00:00, 3.91it/s, kld=47.8, perceptual=0.198, generator=3.63, feature=0.133, discriminator=1.28]" ] }, { @@ -1683,8 +1661,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=495, perceptual=0.252, generator=1.18, feature=0.0635, discriminator=3.82]\n", - "100%|██████| 2/2 [00:00<00:00, 3.88it/s, kld=337, perceptual=0.267, generator=1.15, feature=0.0676, discriminator=3.88]" + "100%|██████| 15/15 [00:07<00:00, 2.13it/s, kld=80.8, perceptual=0.174, generator=3.63, feature=0.13, discriminator=1.3]\n", + "100%|███████| 2/2 [00:00<00:00, 4.20it/s, kld=62.5, perceptual=0.172, generator=3.63, feature=0.134, discriminator=1.3]" ] }, { @@ -1699,8 +1677,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.21it/s, kld=677, perceptual=0.271, generator=1.25, feature=0.0701, discriminator=3.95]\n", - "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=563, perceptual=0.279, generator=1.11, feature=0.0637, discriminator=3.93]" + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=88.8, perceptual=0.208, generator=3.64, feature=0.135, discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=48.4, perceptual=0.196, generator=3.64, feature=0.135, discriminator=1.28]" ] }, { @@ -1715,8 +1693,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=373, perceptual=0.283, generator=1.24, feature=0.0668, discriminator=3.91]\n", - "100%|██████| 2/2 [00:00<00:00, 3.75it/s, kld=235, perceptual=0.266, generator=1.25, feature=0.0631, discriminator=3.89]" + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=100, perceptual=0.202, generator=3.63, feature=0.135, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.05it/s, kld=61.3, perceptual=0.208, generator=3.65, feature=0.135, discriminator=1.27]" ] }, { @@ -1731,8 +1709,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=455, perceptual=0.275, generator=1.23, 
feature=0.0644, discriminator=3.96]\n", - "100%|██████| 2/2 [00:00<00:00, 4.21it/s, kld=363, perceptual=0.263, generator=1.19, feature=0.0624, discriminator=3.91]" + "100%|█████| 15/15 [00:06<00:00, 2.29it/s, kld=116, perceptual=0.201, generator=3.64, feature=0.134, discriminator=1.29]\n", + "100%|████████| 2/2 [00:00<00:00, 4.46it/s, kld=68, perceptual=0.188, generator=3.64, feature=0.132, discriminator=1.29]" ] }, { @@ -1747,13 +1725,13 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=519, perceptual=0.264, generator=1.29, feature=0.0628, discriminator=3.9]\n", - " 0%| | 0/2 [00:00" ] @@ -1765,7 +1743,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████| 2/2 [00:00<00:00, 3.35it/s, kld=375, perceptual=0.261, generator=1.29, feature=0.0589, discriminator=3.89]" + "100%|███████| 2/2 [00:00<00:00, 3.44it/s, kld=54.3, perceptual=0.17, generator=3.64, feature=0.126, discriminator=1.29]" ] }, { @@ -1780,8 +1758,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=329, perceptual=0.242, generator=1.18, feature=0.0615, discriminator=4.02]\n", - "100%|███████| 2/2 [00:00<00:00, 3.73it/s, kld=205, perceptual=0.267, generator=1.16, feature=0.067, discriminator=3.97]" + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=102, perceptual=0.186, generator=3.65, feature=0.132, discriminator=1.27]\n", + "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=63.4, perceptual=0.191, generator=3.66, feature=0.131, discriminator=1.25]" ] }, { @@ -1796,8 +1774,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=629, perceptual=0.233, generator=1.37, feature=0.0524, discriminator=3.98]\n", - "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=417, perceptual=0.233, generator=1.38, feature=0.0535, discriminator=3.95]" + "100%|█████| 15/15 [00:06<00:00, 2.17it/s, kld=161, perceptual=0.191, generator=3.64, feature=0.132, discriminator=1.29]\n", + "100%|███████| 
2/2 [00:00<00:00, 4.33it/s, kld=77.6, perceptual=0.165, generator=3.64, feature=0.13, discriminator=1.27]" ] }, { @@ -1812,8 +1790,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=198, perceptual=0.245, generator=1.45, feature=0.0611, discriminator=3.97]\n", - "100%|███████| 2/2 [00:00<00:00, 4.12it/s, kld=129, perceptual=0.263, generator=1.43, feature=0.071, discriminator=3.99]" + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=102, perceptual=0.185, generator=3.61, feature=0.132, discriminator=1.31]\n", + "100%|██████| 2/2 [00:00<00:00, 4.36it/s, kld=64.2, perceptual=0.204, generator=3.64, feature=0.137, discriminator=1.29]" ] }, { @@ -1828,8 +1806,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=466, perceptual=0.279, generator=1.78, feature=0.0656, discriminator=3.92]\n", - "100%|██████| 2/2 [00:00<00:00, 3.96it/s, kld=308, perceptual=0.278, generator=1.79, feature=0.0681, discriminator=3.86]" + "100%|█████| 15/15 [00:06<00:00, 2.18it/s, kld=71.9, perceptual=0.188, generator=3.63, feature=0.135, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.13it/s, kld=50.9, perceptual=0.206, generator=3.65, feature=0.131, discriminator=1.26]" ] }, { @@ -1844,8 +1822,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=736, perceptual=0.245, generator=1.53, feature=0.0834, discriminator=4.15]\n", - "100%|██████| 2/2 [00:00<00:00, 4.08it/s, kld=668, perceptual=0.235, generator=1.41, feature=0.0861, discriminator=4.06]" + "100%|████| 15/15 [00:06<00:00, 2.25it/s, kld=72.6, perceptual=0.179, generator=3.65, feature=0.124, discriminator=1.26]\n", + "100%|███████| 2/2 [00:00<00:00, 4.15it/s, kld=44.3, perceptual=0.182, generator=3.58, feature=0.13, discriminator=1.34]" ] }, { @@ -1860,8 +1838,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=815, perceptual=0.277, generator=1.98, feature=0.129, 
discriminator=3.59]\n", - "100%|███████| 2/2 [00:00<00:00, 4.03it/s, kld=525, perceptual=0.268, generator=2.05, feature=0.133, discriminator=3.32]" + "100%|██████| 15/15 [00:06<00:00, 2.20it/s, kld=165, perceptual=0.177, generator=3.64, feature=0.131, discriminator=1.3]\n", + "100%|████████| 2/2 [00:00<00:00, 4.08it/s, kld=95, perceptual=0.205, generator=3.63, feature=0.138, discriminator=1.31]" ] }, { @@ -1876,8 +1854,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.24it/s, kld=1e+3, perceptual=0.276, generator=2.28, feature=0.115, discriminator=3.72]\n", - "100%|███████| 2/2 [00:00<00:00, 3.66it/s, kld=660, perceptual=0.282, generator=2.32, feature=0.133, discriminator=3.53]" + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=107, perceptual=0.179, generator=3.64, feature=0.126, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.41it/s, kld=56.1, perceptual=0.203, generator=3.63, feature=0.138, discriminator=1.3]" ] }, { @@ -1892,8 +1870,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██| 15/15 [00:06<00:00, 2.21it/s, kld=1.36e+3, perceptual=0.279, generator=2.47, feature=0.14, discriminator=2.85]\n", - "100%|███████| 2/2 [00:00<00:00, 4.23it/s, kld=957, perceptual=0.297, generator=2.59, feature=0.138, discriminator=2.82]" + "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=82.6, perceptual=0.174, generator=3.63, feature=0.126, discriminator=1.29]\n", + "100%|███████| 2/2 [00:00<00:00, 4.19it/s, kld=56.4, perceptual=0.19, generator=3.64, feature=0.136, discriminator=1.26]" ] }, { @@ -1908,8 +1886,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.22it/s, kld=2.86e+3, perceptual=0.311, generator=1.84, feature=0.142, discriminator=3.38]\n", - "100%|████| 2/2 [00:00<00:00, 3.54it/s, kld=2.04e+3, perceptual=0.328, generator=1.89, feature=0.148, discriminator=3.5]" + "100%|████| 15/15 [00:06<00:00, 2.29it/s, kld=88.3, perceptual=0.177, generator=3.65, feature=0.123, discriminator=1.26]\n", + 
"100%|███████| 2/2 [00:00<00:00, 4.33it/s, kld=55.7, perceptual=0.17, generator=3.65, feature=0.128, discriminator=1.24]" ] }, { @@ -1924,13 +1902,13 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.17it/s, kld=699, perceptual=0.338, generator=2.35, feature=0.14, discriminator=2.54]\n", - " 0%| | 0/2 [00:00" ] @@ -1942,7 +1920,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████| 2/2 [00:00<00:00, 3.51it/s, kld=409, perceptual=0.331, generator=2.58, feature=0.147, discriminator=2.5]" + "100%|██████| 2/2 [00:00<00:00, 3.51it/s, kld=62.6, perceptual=0.191, generator=3.66, feature=0.132, discriminator=1.24]" ] }, { @@ -1957,8 +1935,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.25it/s, kld=491, perceptual=0.355, generator=2.74, feature=0.146, discriminator=2.93]\n", - "100%|███████| 2/2 [00:00<00:00, 2.61it/s, kld=359, perceptual=0.314, generator=2.85, feature=0.143, discriminator=2.95]" + "100%|████| 15/15 [00:06<00:00, 2.27it/s, kld=93.2, perceptual=0.183, generator=3.64, feature=0.132, discriminator=1.29]\n", + "100%|███████| 2/2 [00:00<00:00, 4.39it/s, kld=64.2, perceptual=0.197, generator=3.66, feature=0.13, discriminator=1.24]" ] }, { @@ -1973,8 +1951,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=517, perceptual=0.334, generator=2.67, feature=0.14, discriminator=3.12]\n", - "100%|████████| 2/2 [00:00<00:00, 3.93it/s, kld=384, perceptual=0.326, generator=2.72, feature=0.14, discriminator=3.09]" + "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=90.5, perceptual=0.176, generator=3.65, feature=0.131, discriminator=1.28]\n", + "100%|██████| 2/2 [00:00<00:00, 4.01it/s, kld=61.5, perceptual=0.199, generator=3.67, feature=0.131, discriminator=1.23]" ] }, { @@ -1989,8 +1967,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=1.27e+3, perceptual=0.302, generator=1.8, feature=0.0732, 
discriminator=4]\n", - "100%|███████| 2/2 [00:00<00:00, 3.94it/s, kld=815, perceptual=0.31, generator=1.76, feature=0.0643, discriminator=3.98]" + "100%|████| 15/15 [00:06<00:00, 2.26it/s, kld=89.7, perceptual=0.184, generator=3.66, feature=0.129, discriminator=1.24]\n", + "100%|██████| 2/2 [00:00<00:00, 4.12it/s, kld=61.7, perceptual=0.181, generator=3.66, feature=0.123, discriminator=1.25]" ] }, { @@ -2005,8 +1983,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.20it/s, kld=1.01e+3, perceptual=0.281, generator=1.79, feature=0.0657, discriminator=3.96\n", - "100%|███████| 2/2 [00:00<00:00, 4.20it/s, kld=590, perceptual=0.263, generator=1.86, feature=0.071, discriminator=3.95]" + "100%|████| 15/15 [00:06<00:00, 2.22it/s, kld=89.6, perceptual=0.181, generator=3.65, feature=0.133, discriminator=1.25]\n", + "100%|███████| 2/2 [00:00<00:00, 4.47it/s, kld=46.8, perceptual=0.178, generator=3.65, feature=0.13, discriminator=1.26]" ] }, { @@ -2021,8 +1999,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=465, perceptual=0.253, generator=1.94, feature=0.0742, discriminator=3.63]\n", - "100%|███████| 2/2 [00:00<00:00, 4.08it/s, kld=329, perceptual=0.251, generator=1.95, feature=0.0817, discriminator=3.5]" + "100%|█████| 15/15 [00:06<00:00, 2.27it/s, kld=96.5, perceptual=0.192, generator=3.64, feature=0.13, discriminator=1.28]\n", + "100%|███████| 2/2 [00:00<00:00, 4.30it/s, kld=66.1, perceptual=0.19, generator=3.63, feature=0.133, discriminator=1.29]" ] }, { @@ -2037,8 +2015,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=862, perceptual=0.243, generator=1.09, feature=0.0589, discriminator=3.92]\n", - "100%|██████| 2/2 [00:00<00:00, 3.89it/s, kld=591, perceptual=0.238, generator=1.14, feature=0.0605, discriminator=3.81]" + "100%|█████| 15/15 [00:06<00:00, 2.23it/s, kld=150, perceptual=0.197, generator=3.65, feature=0.134, discriminator=1.27]\n", + 
"100%|████████| 2/2 [00:00<00:00, 4.41it/s, kld=72, perceptual=0.208, generator=3.62, feature=0.141, discriminator=1.32]" ] }, { @@ -2053,8 +2031,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.18it/s, kld=1.08e+3, perceptual=0.24, generator=1.74, feature=0.0678, discriminator=3.54]\n", - "100%|███████| 2/2 [00:00<00:00, 4.11it/s, kld=677, perceptual=0.231, generator=1.8, feature=0.0649, discriminator=3.91]" + "100%|██████| 15/15 [00:06<00:00, 2.26it/s, kld=119, perceptual=0.178, generator=3.63, feature=0.129, discriminator=1.3]\n", + "100%|██████| 2/2 [00:00<00:00, 4.19it/s, kld=81.4, perceptual=0.176, generator=3.63, feature=0.133, discriminator=1.28]" ] }, { @@ -2069,8 +2047,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=958, perceptual=0.316, generator=1.8, feature=0.0715, discriminator=3.97]\n", - "100%|██████| 2/2 [00:00<00:00, 4.26it/s, kld=618, perceptual=0.301, generator=1.81, feature=0.0668, discriminator=3.96]" + "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=102, perceptual=0.173, generator=3.64, feature=0.13, discriminator=1.26]\n", + "100%|███████| 2/2 [00:00<00:00, 4.28it/s, kld=54.9, perceptual=0.192, generator=3.63, feature=0.136, discriminator=1.3]" ] }, { @@ -2085,8 +2063,8 @@ "output_type": "stream", "text": [ "\n", - "100%|███| 15/15 [00:06<00:00, 2.16it/s, kld=1.17e+3, perceptual=0.311, generator=1.31, feature=0.0658, discriminator=4]\n", - "100%|██████| 2/2 [00:00<00:00, 3.89it/s, kld=989, perceptual=0.305, generator=1.23, feature=0.0674, discriminator=3.95]" + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=105, perceptual=0.178, generator=3.63, feature=0.137, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.18it/s, kld=85.9, perceptual=0.173, generator=3.64, feature=0.129, discriminator=1.28]" ] }, { @@ -2101,13 +2079,13 @@ "output_type": "stream", "text": [ "\n", - "100%|██| 15/15 [00:06<00:00, 2.23it/s, kld=2.17e+3, perceptual=0.294, generator=1.94, 
feature=0.101, discriminator=3.3]\n", - " 0%| | 0/2 [00:00" ] @@ -2119,7 +2097,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|█████| 2/2 [00:00<00:00, 2.79it/s, kld=1.5e+3, perceptual=0.301, generator=1.96, feature=0.12, discriminator=3.16]" + "100%|██████| 2/2 [00:00<00:00, 3.52it/s, kld=50.1, perceptual=0.159, generator=3.65, feature=0.126, discriminator=1.28]" ] }, { @@ -2134,8 +2112,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=1.88e+3, perceptual=0.29, generator=2.02, feature=0.0721, discriminator=3.34]\n", - "100%|██| 2/2 [00:00<00:00, 4.14it/s, kld=1.26e+3, perceptual=0.249, generator=1.96, feature=0.0852, discriminator=3.27]" + "100%|██████| 15/15 [00:06<00:00, 2.22it/s, kld=120, perceptual=0.175, generator=3.65, feature=0.13, discriminator=1.26]\n", + "100%|██████| 2/2 [00:00<00:00, 3.88it/s, kld=54.5, perceptual=0.193, generator=3.64, feature=0.136, discriminator=1.27]" ] }, { @@ -2150,8 +2128,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.19it/s, kld=1.49e+3, perceptual=0.253, generator=1.96, feature=0.0624, discriminator=3.82\n", - "100%|██████| 2/2 [00:00<00:00, 3.80it/s, kld=862, perceptual=0.256, generator=1.96, feature=0.0589, discriminator=3.98]" + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=85.1, perceptual=0.178, generator=3.67, feature=0.133, discriminator=1.22]\n", + "100%|██████| 2/2 [00:00<00:00, 4.22it/s, kld=56.2, perceptual=0.163, generator=3.66, feature=0.129, discriminator=1.24]" ] }, { @@ -2166,8 +2144,8 @@ "output_type": "stream", "text": [ "\n", - "100%|██████| 15/15 [00:06<00:00, 2.23it/s, kld=592, perceptual=0.252, generator=1.11, feature=0.108, discriminator=4.4]\n", - "100%|█████████| 2/2 [00:00<00:00, 3.92it/s, kld=365, perceptual=0.278, generator=1.19, feature=0.13, discriminator=4.5]" + "100%|██████| 15/15 [00:06<00:00, 2.19it/s, kld=98, perceptual=0.177, generator=3.64, feature=0.136, discriminator=1.27]\n", + "100%|██████| 2/2 
[00:00<00:00, 4.38it/s, kld=63.8, perceptual=0.192, generator=3.66, feature=0.131, discriminator=1.24]" ] }, { @@ -2182,8 +2160,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.21it/s, kld=1.65e+3, perceptual=0.267, generator=1.77, feature=0.121, discriminator=3.45]\n", - "100%|███| 2/2 [00:00<00:00, 4.14it/s, kld=1.05e+3, perceptual=0.244, generator=1.85, feature=0.123, discriminator=3.18]" + "100%|█████| 15/15 [00:06<00:00, 2.22it/s, kld=117, perceptual=0.189, generator=3.64, feature=0.134, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.20it/s, kld=62.4, perceptual=0.158, generator=3.63, feature=0.124, discriminator=1.29]" ] }, { @@ -2198,8 +2176,8 @@ "output_type": "stream", "text": [ "\n", - "100%|█| 15/15 [00:06<00:00, 2.24it/s, kld=1.53e+3, perceptual=0.279, generator=1.74, feature=0.114, discriminator=3.33]\n", - "100%|██████| 2/2 [00:00<00:00, 4.29it/s, kld=826, perceptual=0.259, generator=1.69, feature=0.0888, discriminator=3.51]" + "100%|█████| 15/15 [00:06<00:00, 2.24it/s, kld=115, perceptual=0.186, generator=3.67, feature=0.131, discriminator=1.22]\n", + "100%|██████| 2/2 [00:00<00:00, 4.25it/s, kld=72.7, perceptual=0.161, generator=3.64, feature=0.126, discriminator=1.28]" ] }, { @@ -2214,8 +2192,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.20it/s, kld=940, perceptual=0.264, generator=1.88, feature=0.0612, discriminator=3.97]\n", - "100%|██████| 2/2 [00:00<00:00, 4.26it/s, kld=580, perceptual=0.262, generator=1.55, feature=0.0873, discriminator=4.16]" + "100%|████| 15/15 [00:06<00:00, 2.16it/s, kld=83.2, perceptual=0.161, generator=3.66, feature=0.127, discriminator=1.24]\n", + "100%|████████| 2/2 [00:00<00:00, 4.27it/s, kld=76, perceptual=0.166, generator=3.65, feature=0.128, discriminator=1.27]" ] }, { @@ -2230,8 +2208,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=588, perceptual=0.275, generator=1.81, feature=0.0542, 
discriminator=3.85]\n", - "100%|██████| 2/2 [00:00<00:00, 3.94it/s, kld=549, perceptual=0.262, generator=1.94, feature=0.0645, discriminator=3.75]" + "100%|█████| 15/15 [00:06<00:00, 2.21it/s, kld=100, perceptual=0.173, generator=3.64, feature=0.128, discriminator=1.27]\n", + "100%|███████| 2/2 [00:00<00:00, 3.91it/s, kld=52.8, perceptual=0.16, generator=3.64, feature=0.122, discriminator=1.27]" ] }, { @@ -2246,8 +2224,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.18it/s, kld=658, perceptual=0.284, generator=1.73, feature=0.0515, discriminator=3.94]\n", - "100%|██████| 2/2 [00:00<00:00, 4.15it/s, kld=362, perceptual=0.271, generator=1.79, feature=0.0583, discriminator=3.94]" + "100%|█████| 15/15 [00:06<00:00, 2.19it/s, kld=105, perceptual=0.188, generator=3.63, feature=0.133, discriminator=1.29]\n", + "100%|██████| 2/2 [00:00<00:00, 4.30it/s, kld=66.9, perceptual=0.182, generator=3.66, feature=0.131, discriminator=1.22]" ] }, { @@ -2262,8 +2240,8 @@ "output_type": "stream", "text": [ "\n", - "100%|████| 15/15 [00:06<00:00, 2.23it/s, kld=808, perceptual=0.281, generator=1.83, feature=0.0545, discriminator=3.86]\n", - "100%|██████| 2/2 [00:00<00:00, 2.92it/s, kld=598, perceptual=0.266, generator=1.79, feature=0.0619, discriminator=3.89]\n" + "100%|██████| 15/15 [00:06<00:00, 2.24it/s, kld=107, perceptual=0.164, generator=3.64, feature=0.126, discriminator=1.3]\n", + "100%|███████| 2/2 [00:00<00:00, 3.99it/s, kld=63.3, perceptual=0.16, generator=3.66, feature=0.129, discriminator=1.25]\n" ] } ], @@ -2271,6 +2249,8 @@ "net = net.to(device)\n", "discriminator = discriminator.to(device)\n", "torch.autograd.set_detect_anomaly(True)\n", + "losses = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []}\n", + "losses_val = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []}\n", "for epoch in range(num_epochs):\n", " print(\"Epoch %d/%d\" %(epoch, num_epochs))\n", " train_bar = 
tqdm(enumerate(train_loader), total=len(train_loader), ncols=120)\n", @@ -2355,22 +2335,78 @@ " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", " })\n", " if step == 0 and epoch%10==0:\n", - " picture_results(label, image, out)" + " picture_results(label, image, out)\n", + " for key, val in losses_epoch.items():\n", + " losses[key].append(val / len(train_loader))\n", + " for key, val in losses_epoch_val.items():\n", + " losses_val[key].append(val / len(val_loader))\n", + " " ] }, { "cell_type": "code", - "execution_count": null, - "id": "f579376b", + "execution_count": 28, + "id": "d79c8ced", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot losses\n", + "colors = ['orangered', 'royalblue', 'hotpink', 'lime', 'goldenrod']\n", + "plt.figure(figsize=(5,10))\n", + "ind = 0\n", + "for key, val in losses.items():\n", + " plt.subplot(len(losses.keys()),1,ind+1)\n", + " plt.plot(val, color = colors[ind], linestyle = '-')\n", + " plt.plot(losses_val[key], color = colors[ind], linestyle = '--')\n", + " plt.title(key)\n", + " plt.xlabel(\"Epochs\")\n", + " ind+=1;\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c3cf096f", "metadata": { "pycharm": { "name": "#%%" } }, - "outputs": [], "source": [ - "\n" + "**Conclusion**: from early on, the network shows the capability of discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid overfitting, seen in some loss plots of previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the discriminator and generator and the discriminator are trained only when their performance does not exceed a certain limit.\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cf9ab5f", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -2392,9 +2428,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.py b/tutorials/generative/2d_spade_gan/2d_spade_vae.py index 1f2a93ac..3bb1f04f 100644 --- a/tutorials/generative/2d_spade_gan/2d_spade_vae.py +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.py @@ -239,6 
+239,8 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat net = net.to(device) discriminator = discriminator.to(device) torch.autograd.set_detect_anomaly(True) +losses = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []} +losses_val = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []} for epoch in range(num_epochs): print("Epoch %d/%d" %(epoch, num_epochs)) train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120) @@ -324,7 +326,29 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat }) if step == 0 and epoch%10==0: picture_results(label, image, out) + for key, val in losses_epoch.items(): + losses[key].append(val / len(train_loader)) + for key, val in losses_epoch_val.items(): + losses_val[key].append(val / len(val_loader)) + + +# Plot losses +colors = ['orangered', 'royalblue', 'hotpink', 'lime', 'goldenrod'] +plt.figure(figsize=(5,10)) +ind = 0 +for key, val in losses.items(): + plt.subplot(len(losses.keys()),1,ind+1) + plt.plot(val, color = colors[ind], linestyle = '-') + plt.plot(losses_val[key], color = colors[ind], linestyle = '--') + plt.title(key) + plt.xlabel("Epochs") + ind+=1; +plt.tight_layout() +plt.show() -# + pycharm={"name": "#%%"} +# + [markdown] pycharm={"name": "#%%"} +# **Conclusion**: from early on, the network shows the capability of discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid overfitting, seen in some loss plots of previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the discriminator and generator and the discriminator are trained only when their performance does not exceed a certain limit. 
+# +# - From 39ed0f161b31c722c1f99f972a77231415a02b80 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 5 Jul 2023 10:22:11 -0600 Subject: [PATCH 6/7] Addresses hidden comments and runtests with autofix --- generative/losses/kld_loss.py | 5 +- generative/networks/blocks/spade_norm.py | 90 ++--- generative/networks/nets/spade_network.py | 314 ++++++++++-------- tests/test_spade_vaegan.py | 94 +++--- .../2d_spade_gan/2d_spade_vae.ipynb | 46 +-- .../generative/2d_spade_gan/2d_spade_vae.py | 274 +++++++-------- 6 files changed, 434 insertions(+), 389 deletions(-) diff --git a/generative/losses/kld_loss.py b/generative/losses/kld_loss.py index d3178b84..ebcaf52b 100644 --- a/generative/losses/kld_loss.py +++ b/generative/losses/kld_loss.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch.nn as nn +from __future__ import annotations + import torch +import torch.nn as nn + class KLDLoss(nn.Module): def forward(self, mu, logvar): diff --git a/generative/networks/blocks/spade_norm.py b/generative/networks/blocks/spade_norm.py index 419f740e..0fe735e8 100644 --- a/generative/networks/blocks/spade_norm.py +++ b/generative/networks/blocks/spade_norm.py @@ -10,80 +10,84 @@ # limitations under the License. from __future__ import annotations + import torch import torch.nn as nn -from monai.networks.blocks import Convolution, ADN import torch.nn.functional as F +from monai.networks.blocks import ADN, Convolution + class SPADE(nn.Module): """ SPADE normalisation block based on the 2019 paper by Park et al. 
(doi: https://doi.org/10.48550/arXiv.1903.07291) + Args: label_nc: number of semantic labels norm_nc: number of output channels kernel_size: kernel size spatial_dims: number of spatial dimensions hidden_channels: number of channels in the intermediate gamma and beta layers - normalisation: type of base normalisation used before applying the SPADE normalisation + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation """ - def __init__(self, - label_nc: int, - norm_nc: int, - kernel_size: int = 3, - spatial_dims: int = 2, - hidden_channels: int = 64, - norm: str | tuple= "INSTANCE", - norm_params: dict = {} - )-> None: + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict = {}, + ) -> None: super().__init__() if len(norm_params) != 0: norm = (norm, norm_params) - self.param_free_norm = ADN(act=None, dropout=0.0, norm = norm, - norm_dim=spatial_dims, - ordering="N", - in_channels=norm_nc) - self.mlp_shared = Convolution(spatial_dims=spatial_dims, - in_channels = label_nc, - out_channels = hidden_channels, - kernel_size= kernel_size, - norm = None, - padding=kernel_size//2, - act="LEAKYRELU") - self.mlp_gamma = Convolution(spatial_dims=spatial_dims, - in_channels=hidden_channels, - out_channels=norm_nc, - kernel_size=kernel_size, - padding = kernel_size//2, - act = None - ) - self.mlp_beta = Convolution(spatial_dims=spatial_dims, - in_channels=hidden_channels, - out_channels=norm_nc, - kernel_size=kernel_size, - padding = kernel_size//2, - act = None - ) - + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + padding=kernel_size 
// 2, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) - def forward(self, - x: torch.Tensor, - segmap: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: """ Args: x: input tensor segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. The map will be interpolated to the dimension of x internally. - Returns: - """ # Part 1. generate parameter-free normalized activations normalized = self.param_free_norm(x) # Part 2. produce scaling and bias conditioned on semantic map - segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") actv = self.mlp_shared(segmap) gamma = self.mlp_gamma(actv) beta = self.mlp_beta(actv) diff --git a/generative/networks/nets/spade_network.py b/generative/networks/nets/spade_network.py index 0be1d9c8..b2ff2833 100644 --- a/generative/networks/nets/spade_network.py +++ b/generative/networks/nets/spade_network.py @@ -9,25 +9,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + +from typing import Sequence, Union + +import numpy as np import torch import torch.nn as nn -from typing import Union, Sequence -import numpy as np +import torch.nn.functional as F from monai.networks.blocks import Convolution from monai.networks.layers import Act -from generative.networks.blocks.spade_norm import SPADE from monai.utils.enums import StrEnum -import torch.nn.functional as F + from generative.losses.kld_loss import KLDLoss +from generative.networks.blocks.spade_norm import SPADE + class UpsamplingModes(StrEnum): bicubic = "bicubic" nearest = "nearest" bilinear = "bilinear" + class SPADE_ResNetBlock(nn.Module): """ Creates a Residual Block with SPADE normalisation. + Args: spatial_dims: number of spatial dimensions in_channels: number of input channels @@ -38,51 +45,67 @@ class SPADE_ResNetBlock(nn.Module): kernel_size: convolutional kernel size """ - def __init__(self, - spatial_dims: int, - in_channels: int, - out_channels: int, - label_nc: int, - spade_intermediate_channels: int = 128, - norm: Union[str, tuple] = "INSTANCE", - kernel_size: int = 3,): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + kernel_size: int = 3, + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.int_channels = min(self.in_channels, self.out_channels) self.learned_shortcut = self.in_channels != self.out_channels - self.conv_0 = Convolution(spatial_dims = spatial_dims, - in_channels = self.in_channels, - out_channels = self.int_channels, - act = None, - norm = None, - ) - self.conv_1 = Convolution(spatial_dims = spatial_dims, - in_channels = self.int_channels, - out_channels = self.out_channels, - act = None, - norm = None, - ) + self.conv_0 = Convolution( + spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, 
norm=None + ) + self.conv_1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.int_channels, + out_channels=self.out_channels, + act=None, + norm=None, + ) self.activation = nn.LeakyReLU(0.2, False) - self.norm_0 = SPADE(label_nc=label_nc, norm_nc=self.in_channels, kernel_size=kernel_size, - spatial_dims=spatial_dims, hidden_channels=spade_intermediate_channels, - norm=norm) - self.norm_1 = SPADE(label_nc=label_nc, norm_nc=self.int_channels, kernel_size=kernel_size, - spatial_dims=spatial_dims, hidden_channels=spade_intermediate_channels, - norm=norm) + self.norm_0 = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + self.norm_1 = SPADE( + label_nc=label_nc, + norm_nc=self.int_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) if self.learned_shortcut: - self.conv_s = Convolution(spatial_dims = spatial_dims, - in_channels = self.in_channels, - out_channels = self.out_channels, - act = None, - norm = None, - kernel_size=1, - ) - self.norm_s = SPADE(label_nc=label_nc, norm_nc=self.in_channels, kernel_size=kernel_size, - spatial_dims=spatial_dims, hidden_channels=spade_intermediate_channels, - norm=norm) + self.conv_s = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + act=None, + norm=None, + kernel_size=1, + ) + self.norm_s = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) def forward(self, x, seg): x_s = self.shortcut(x, seg) @@ -98,9 +121,11 @@ def shortcut(self, x, seg): x_s = x return x_s + class SPADE_Encoder(nn.Module): """ Encoding branch of a VAE compatible with a SPADE-like generator + Args: spatial_dims: number of spatial dimensions in_channels: number of 
input channels @@ -112,47 +137,58 @@ class SPADE_Encoder(nn.Module): norm: normalisation layer type act: activation type """ - def __init__(self, - spatial_dims: int, - in_channels: int, - z_dim: int, - num_channels: Sequence[int], - input_shape: Sequence[int], - kernel_size: int = 3, - norm: Union[str, tuple] = "INSTANCE", - act: Union[str, tuple] = (Act.LEAKYRELU, {"negative_slope": 0.2})): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + z_dim: int, + num_channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + ): super().__init__() self.in_channels = in_channels self.z_dim = z_dim self.num_channels = num_channels if len(input_shape) != spatial_dims: - raise ValueError("Length of parameter input shape must match spatial_dims; got %s" %(input_shape)) + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) for s_ind, s_ in enumerate(input_shape): if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): - raise ValueError("Each dimension of your input must be divisible by 2 ** (autoencoder depth)." - "The shape in position %d, %d is not divisible by %d. " %(s_ind, s_, len(num_channels))) + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. 
" % (s_ind, s_, len(num_channels)) + ) self.input_shape = input_shape self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in self.input_shape] blocks = [] ch_init = self.in_channels for ch_ind, ch_value in enumerate(num_channels): - blocks.append(Convolution(spatial_dims = spatial_dims, - in_channels = ch_init, - out_channels= ch_value, - strides=2, - kernel_size=kernel_size, - norm = norm, - act = act)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=ch_init, + out_channels=ch_value, + strides=2, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + ) ch_init = ch_value self.blocks = nn.ModuleList(blocks) - self.fc_mu = nn.Linear(in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], - out_features=self.z_dim) - self.fc_var = nn.Linear(in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], - out_features=self.z_dim) + self.fc_mu = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], out_features=self.z_dim + ) + self.fc_var = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.num_channels[-1], out_features=self.z_dim + ) - def forward(self, x,): + def forward(self, x): for block in self.blocks: x = block(x) x = x.view(x.size(0), -1) @@ -174,10 +210,12 @@ def reparameterize(self, mu, logvar): eps = torch.randn_like(std) return eps.mul(std) + mu + class SPADE_Decoder(nn.Module): """ Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, behaving like a GAN, or coupled to a SPADE encoder. + Args: label_nc: number of semantic labels spatial_dims: number of spatial dimensions @@ -194,21 +232,23 @@ class SPADE_Decoder(nn.Module): kernel_size: convolutional kernel size upsampling_mode: upsampling mode (nearest, bilinear etc.) 
""" - def __init__(self, - spatial_dims: int, - out_channels: int, - label_nc: int, - input_shape: Sequence[int], - num_channels: Sequence[int], - z_dim: Union[int, None] = None, - is_gan: bool = False, - spade_intermediate_channels: int = 128, - norm: Union[str, tuple] = "INSTANCE", - act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - kernel_size: int = 3, - upsampling_mode: str = UpsamplingModes.nearest.value, - ): + + def __init__( + self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_gan: bool = False, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): super().__init__() self.is_gan = is_gan @@ -219,9 +259,10 @@ def __init__(self, raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) for s_ind, s_ in enumerate(input_shape): if s_ / (2 ** len(num_channels)) != s_ // (2 ** len(num_channels)): - raise ValueError("Each dimension of your input must be divisible by 2 ** (autoencoder depth)." - "The shape in position %d, %d is not divisible by %d. " % ( - s_ind, s_, len(num_channels))) + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. 
" % (s_ind, s_, len(num_channels)) + ) self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] if self.is_gan: @@ -233,24 +274,28 @@ def __init__(self, num_channels.append(self.out_channels) self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) for ch_ind, ch_value in enumerate(num_channels[:-1]): - blocks.append(SPADE_ResNetBlock(spatial_dims=spatial_dims, - in_channels=ch_value, - out_channels=num_channels[ch_ind+1], - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - norm=norm, - kernel_size=kernel_size),) + blocks.append( + SPADE_ResNetBlock( + spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=num_channels[ch_ind + 1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size, + ) + ) self.blocks = torch.nn.ModuleList(blocks) - self.last_conv = Convolution(spatial_dims=spatial_dims, - in_channels=num_channels[-1], - out_channels=out_channels, - padding=(kernel_size-1)//2, - kernel_size=kernel_size, - norm = None, - act=last_act - ) - + self.last_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + padding=(kernel_size - 1) // 2, + kernel_size=kernel_size, + norm=None, + act=last_act, + ) def forward(self, seg, z: torch.Tensor = None): if self.is_gan: @@ -258,10 +303,9 @@ def forward(self, seg, z: torch.Tensor = None): x = self.fc(x) else: if z is None: - z = torch.randn(seg.size(0), self.opt.z_dim, - dtype=torch.float32, device=seg.get_device()) + z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=seg.get_device()) x = self.fc(z) - x = x.view(*[-1, self.num_channels[0]]+self.latent_spatial_shape) + x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) for res_block in self.blocks: x = res_block(x, seg) @@ -270,11 +314,13 @@ def forward(self, seg, z: torch.Tensor = None): x = self.last_conv(x) return x + class 
SPADE_Net(nn.Module): """ SPADE Network, implemented based on the code by Park, T et al. in "Semantic Image Synthesis with Spatially-Adaptive Normalization" (https://github.com/NVlabs/SPADE) + Args: spatial_dims: number of spatial dimensions in_channels: number of input channels @@ -293,22 +339,21 @@ class SPADE_Net(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - label_nc: int, - input_shape: Sequence[int], - num_channels: Sequence[int], - z_dim: Union[int, None] = None, - is_vae: bool = True, - spade_intermediate_channels: int = 128, - norm:Union[str, tuple] = "INSTANCE", - act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), - kernel_size: int = 3, - upsampling_mode: str = UpsamplingModes.nearest.value - + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + num_channels: Sequence[int], + z_dim: Union[int, None] = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: Union[str, tuple] = "INSTANCE", + act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: Union[str, tuple, None] = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, ): super().__init__() @@ -325,14 +370,15 @@ def __init__( if self.is_vae: self.encoder = SPADE_Encoder( - spatial_dims = spatial_dims, - in_channels = in_channels, - z_dim = z_dim, - num_channels = num_channels, - input_shape = input_shape, - kernel_size = kernel_size, - norm = norm, - act = act) + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + num_channels=num_channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) decoder_channels = num_channels decoder_channels.reverse() @@ -342,15 +388,15 @@ def __init__( out_channels=out_channels, 
label_nc=label_nc, input_shape=input_shape, - num_channels= decoder_channels, - z_dim = z_dim, - is_gan = not is_vae, - spade_intermediate_channels = spade_intermediate_channels, - norm = norm, - act = act, - last_act = last_act, + num_channels=decoder_channels, + z_dim=z_dim, + is_gan=not is_vae, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + act=act, + last_act=last_act, kernel_size=kernel_size, - upsampling_mode=upsampling_mode + upsampling_mode=upsampling_mode, ) def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): @@ -361,7 +407,7 @@ def forward(self, seg: torch.Tensor, x: Union[torch.Tensor, None] = None): kld_loss = self.kld_loss(z_mu, z_logvar) return self.decoder(seg, z), kld_loss else: - return self.decoder(seg, z), + return (self.decoder(seg, z),) def encode(self, x: torch.Tensor): diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 3354d4d6..e030b81e 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -10,79 +10,91 @@ # limitations under the License. from __future__ import annotations + import unittest + +import numpy as np import torch from monai.networks import eval_mode from parameterized import parameterized -from networks.nets.spade_network import SPADE_Net -import numpy as np + +from generative.networks.nets.spade_network import SPADE_Net CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] -def create_Semantic_Data(shape:list, semantic_regions:int): - ''' + +def create_semantic_data(shape: list, semantic_regions: int): + """ To create semantic and image mock inputs for the network. 
Args: shape: input shape semantic_regions: number of semantic regions Returns: - ''' + """ out_label = torch.zeros(shape) - out_image = torch.zeros(shape) + torch.randn(shape)*0.01 + out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 for i in range(1, semantic_regions): - shape_square = [i//np.random.choice(list(range(2, i//2))) for i in shape] - start_point = [np.random.choice(list(range(shape[ind]-shape_square[ind]))) - for ind, i in enumerate(shape)] + shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] if len(shape) == 2: - out_label[start_point[0]:(start_point[0]+shape_square[0]), - start_point[1]:(start_point[1]+shape_square[1])] = i + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = i base_intensity = torch.ones(shape_square) * np.random.randn() - out_image[start_point[0]:(start_point[0] + shape_square[0]), - start_point[1]:(start_point[1] + shape_square[1])] = base_intensity + \ - torch.randn(shape_square)*0.1 + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = (base_intensity + torch.randn(shape_square) * 0.1) elif len(shape) == 3: - out_label[start_point[0]:(start_point[0]+shape_square[0]), - start_point[1]:(start_point[1]+shape_square[1]), - start_point[2]:(start_point[2] + shape_square[2])] = i + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = i base_intensity = torch.ones(shape_square) * np.random.randn() - out_image[start_point[0]:(start_point[0]+shape_square[0]), - start_point[1]:(start_point[1]+shape_square[1]), - start_point[2]:(start_point[2] + shape_square[2])] = base_intensity + \ - torch.randn(shape_square)*0.1 + out_image[ 
+ start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = (base_intensity + torch.randn(shape_square) * 0.1) else: ValueError("Supports only 2D and 3D tensors") # One hot encode label - out_label_ = torch.zeros([semantic_regions,] + list(out_label.shape)) + out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) for ch in range(semantic_regions): out_label_[ch, ...] = out_label == ch return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) -class TestDiffusionModelUNet2D(unittest.TestCase): +class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(CASE_2D) def test_forward_2d(self, input_param): - ''' + """ Check that forward method is called correctly and output shape matches. - ''' + """ net = SPADE_Net(*input_param) - in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out, kld = net(in_label, in_image) - self.assertEqual(False, True in torch.isnan(out) or True in torch.isinf(out) - or True in torch.isinf(kld) or True in torch.isinf(kld)) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) self.assertEqual(list(out.shape), [1, 1, 64, 64]) @parameterized.expand(CASE_2D_BIS) def test_encoder_decoder(self, input_param): - ''' + """ Check that forward method is called correctly and output shape matches. 
- ''' + """ net = SPADE_Net(*input_param) - in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out_z = net.encode(in_image) self.assertEqual(list(out_z.shape), [1, 16]) @@ -91,23 +103,29 @@ def test_encoder_decoder(self, input_param): @parameterized.expand(CASE_3D) def test_forward_3d(self, input_param): - ''' + """ Check that forward method is called correctly and output shape matches. - ''' + """ net = SPADE_Net(*input_param) - in_label, in_image = create_Semantic_Data(input_param[4], input_param[3]) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out, kld = net(in_label, in_image) - self.assertEqual(False, True in torch.isnan(out) or True in torch.isinf(out) - or True in torch.isinf(kld) or True in torch.isinf(kld)) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) def test_shape_wrong(self): - ''' + """ We input an input shape that isn't divisible by 2**(n downstream steps) - ''' + """ with self.assertRaises(ValueError): net = SPADE_Net(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + if __name__ == "__main__": unittest.main() diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb index e5c64a3d..9f5333b6 100644 --- a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -65,9 +65,7 @@ "cell_type": "code", "execution_count": 2, "id": "e76296e7", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -357,14 +355,6 @@ " plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "eaa62145", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": 
"code", "execution_count": 12, @@ -427,9 +417,7 @@ "cell_type": "code", "execution_count": 15, "id": "36ea4308", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -472,7 +460,7 @@ "execution_count": 27, "id": "918eac0a", "metadata": { - "scrolled": false + "lines_to_next_cell": 2 }, "outputs": [ { @@ -2339,8 +2327,7 @@ " for key, val in losses_epoch.items():\n", " losses[key].append(val / len(train_loader))\n", " for key, val in losses_epoch_val.items():\n", - " losses_val[key].append(val / len(val_loader))\n", - " " + " losses_val[key].append(val / len(val_loader))" ] }, { @@ -2380,6 +2367,7 @@ "cell_type": "markdown", "id": "c3cf096f", "metadata": { + "lines_to_next_cell": 0, "pycharm": { "name": "#%%" } @@ -2387,26 +2375,6 @@ "source": [ "**Conclusion**: from early on, the network shows the capability of discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid overfitting, seen in some loss plots of previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the discriminator and generator and the discriminator are trained only when their performance does not exceed a certain limit.\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8cf9ab5f", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -2428,9 +2396,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.py b/tutorials/generative/2d_spade_gan/2d_spade_vae.py index 3bb1f04f..ec286f29 100644 --- 
a/tutorials/generative/2d_spade_gan/2d_spade_vae.py +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.py @@ -49,7 +49,7 @@ directory = os.environ.get("MONAI_DATA_DIRECTORY") root_dir = tempfile.mkdtemp() if directory is None else directory root_dir = Path(root_dir) -print("Temporary directory used: %s " %root_dir) +print("Temporary directory used: %s " % root_dir) # INPUT PARAMETERS input_shape = [128, 128] @@ -80,58 +80,63 @@ # [3] Fiford CM, Sudre CH, Pemberton H, Walsh P, Manning E, Malone IB, Nicholas J, Bouvy WH, Carmichael OT, Biessels GJ, Cardoso MJ, Barnes J; Alzheimer’s Disease Neuroimaging Initiative. Automated White Matter Hyperintensity Segmentation Using Bayesian Model Selection: Assessment and Correlations with Cognitive Change. Neuroinformatics. 2020 Jun;18(3):429-449. doi: 10.1007/s12021-019-09439-6. PMID: 32062817; PMCID: PMC7338814. # -gdown.download("https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5", - str(root_dir / 'data.zip')) +gdown.download( + "https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5", str(root_dir / "data.zip") +) -zip_obj = zipfile.ZipFile(os.path.join(root_dir, 'data.zip'), 'r') +zip_obj = zipfile.ZipFile(os.path.join(root_dir, "data.zip"), "r") zip_obj.extractall(root_dir) images_T1 = root_dir / "OASIS_SMALL-SUBSET/T1" images_FLAIR = root_dir / "OASIS_SMALL-SUBSET/FLAIR" labels = root_dir / "OASIS_SMALL-SUBSET/Segmentations" # We create the data dictionaries that we need -all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + \ - [os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)] +all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + [ + os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR) +] np.random.shuffle(all_images) -corresponding_labels = [os.path.join(labels, i.split("/")[-1].replace(i.split("/")[-1].split("_")[0], "Parcellation")) - for i in all_images] -input_dict = [{'image': i, 
'label': corresponding_labels[ind]} for ind, i in enumerate(all_images)] -input_dict_train = input_dict[:int(len(input_dict)*0.9)] -input_dict_val = input_dict[int(len(input_dict)*0.9):] +corresponding_labels = [ + os.path.join(labels, i.split("/")[-1].replace(i.split("/")[-1].split("_")[0], "Parcellation")) for i in all_images +] +input_dict = [{"image": i, "label": corresponding_labels[ind]} for ind, i in enumerate(all_images)] +input_dict_train = input_dict[: int(len(input_dict) * 0.9)] +input_dict_val = input_dict[int(len(input_dict) * 0.9) :] # ### Dataloaders # + -preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain +preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain crop_shape = input_shape + [1] base_transforms = [ - monai.transforms.LoadImaged(keys = ['label', 'image']), - monai.transforms.EnsureChannelFirstd(keys=['image', 'label']), - monai.transforms.CenterSpatialCropd(keys=['label', 'image'], - roi_size=preliminar_shape), - monai.transforms.RandSpatialCropd(keys = ['label', 'image'], - roi_size=crop_shape, max_roi_size=crop_shape), - monai.transforms.SqueezeDimd(keys=['label', 'image'], dim = -1), - monai.transforms.Resized(keys = ['image', 'label'], spatial_size=input_shape), + monai.transforms.LoadImaged(keys=["label", "image"]), + monai.transforms.EnsureChannelFirstd(keys=["image", "label"]), + monai.transforms.CenterSpatialCropd(keys=["label", "image"], roi_size=preliminar_shape), + monai.transforms.RandSpatialCropd(keys=["label", "image"], roi_size=crop_shape, max_roi_size=crop_shape), + monai.transforms.SqueezeDimd(keys=["label", "image"], dim=-1), + monai.transforms.Resized(keys=["image", "label"], spatial_size=input_shape), ] last_transforms = [ - monai.transforms.CopyItemsd(keys=['label'], names=['label_channel']), - monai.transforms.Lambdad(keys=['label_channel'], - func=lambda l: l != 0), - monai.transforms.MaskIntensityd(keys=['image'], mask_key='label_channel'), 
- monai.transforms.NormalizeIntensityd(keys=['image']), - monai.transforms.ToTensord(keys=['image', 'label']) - ] + monai.transforms.CopyItemsd(keys=["label"], names=["label_channel"]), + monai.transforms.Lambdad(keys=["label_channel"], func=lambda l: l != 0), + monai.transforms.MaskIntensityd(keys=["image"], mask_key="label_channel"), + monai.transforms.NormalizeIntensityd(keys=["image"]), + monai.transforms.ToTensord(keys=["image", "label"]), +] aug_transforms = [ - monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=['image']), - monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=['image']), - monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), - keys=['image']), - monai.transforms.RandAffined(rotate_range=[-0.05, 0.05], shear_range=[0.001, 0.05], - scale_range=[0, 0.05], padding_mode='zeros', - mode='nearest', prob=0.33, keys=['label', 'image']) - ] + monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=["image"]), + monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=["image"]), + monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), keys=["image"]), + monai.transforms.RandAffined( + rotate_range=[-0.05, 0.05], + shear_range=[0.001, 0.05], + scale_range=[0, 0.05], + padding_mode="zeros", + mode="nearest", + prob=0.33, + keys=["label", "image"], + ), +] train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms) val_transforms = monai.transforms.Compose(base_transforms + last_transforms) @@ -146,11 +151,13 @@ # Sanity check batch = next(iter(train_loader)) -print(batch['image'].shape) -plt.subplot(1,2,1) -plt.imshow(batch['image'][0,0,...], cmap = 'gist_gray'); plt.axis('off') -plt.subplot(1,2,2) -plt.imshow(batch['label'][0,0,...], cmap = "jet"); plt.axis('off') +print(batch["image"].shape) +plt.subplot(1, 2, 1) +plt.imshow(batch["image"][0, 0, ...], 
cmap="gist_gray") +plt.axis("off") +plt.subplot(1, 2, 2) +plt.imshow(batch["label"][0, 0, ...], cmap="jet") +plt.axis("off") plt.show() # ### Network creation and losses @@ -168,25 +175,23 @@ def one_hot(input_label, label_nc): return label_out - def picture_results(input_label, input_image, output_image): - f = plt.figure(figsize = (4, 1.5)) - plt.subplot(1,3,1) - plt.imshow(torch.argmax(input_label, 1)[0,...].detach().cpu(), cmap = 'jet') - plt.axis('off') + f = plt.figure(figsize=(4, 1.5)) + plt.subplot(1, 3, 1) + plt.imshow(torch.argmax(input_label, 1)[0, ...].detach().cpu(), cmap="jet") + plt.axis("off") plt.title("Label") - plt.subplot(1,3,2) - plt.imshow(input_image[0,0,...].detach().cpu(), cmap = 'gist_gray') - plt.axis('off') + plt.subplot(1, 3, 2) + plt.imshow(input_image[0, 0, ...].detach().cpu(), cmap="gist_gray") + plt.axis("off") plt.title("Input image") - plt.subplot(1,3,3) - plt.imshow(output_image[0,0,...].detach().cpu(), cmap = 'gist_gray') - plt.axis('off') + plt.subplot(1, 3, 3) + plt.imshow(output_image[0, 0, ...].detach().cpu(), cmap="gist_gray") + plt.axis("off") plt.title("Output image") plt.show() - def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat, device): criterion = torch.nn.L1Loss() num_D = len(input_features_disc_fake) @@ -194,44 +199,43 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat for i in range(num_D): # for each discriminator num_intermediate_outputs = len(input_features_disc_fake[i]) for j in range(num_intermediate_outputs): # for each layer output - unweighted_loss = criterion(input_features_disc_fake[i][j], - input_features_disc_real[i][j].detach()) + unweighted_loss = criterion(input_features_disc_fake[i][j], input_features_disc_real[i][j].detach()) GAN_Feat_loss += unweighted_loss * lambda_feat / num_D return GAN_Feat_loss -net = SPADE_Net(spatial_dims = 2, - in_channels = 1, - out_channels = 1, - label_nc = 6, - input_shape = input_shape, - num_channels = 
[16, 32, 64, 128], - z_dim = 16, - is_vae = True) +net = SPADE_Net( + spatial_dims=2, + in_channels=1, + out_channels=1, + label_nc=6, + input_shape=input_shape, + num_channels=[16, 32, 64, 128], + z_dim=16, + is_vae=True, +) # + -discriminator = MultiScalePatchDiscriminator(num_d = 2, - num_layers_d = 3, - spatial_dims = 2, - num_channels = 8, - in_channels = 7, - out_channels = 7, - minimum_size_im = 128, - norm = "INSTANCE", - kernel_size = 3 - ) - -adversarial_loss = PatchAdversarialLoss(reduction = "sum", criterion = "hinge") +discriminator = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + num_channels=8, + in_channels=7, + out_channels=7, + minimum_size_im=128, + norm="INSTANCE", + kernel_size=3, +) + +adversarial_loss = PatchAdversarialLoss(reduction="sum", criterion="hinge") # - -perceptual_loss = PerceptualLoss(spatial_dims = 2, - network_type = "vgg", - is_fake_3d = False, - pretrained = True) -perceptual_loss=perceptual_loss.to(device) +perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg", is_fake_3d=False, pretrained=True) +perceptual_loss = perceptual_loss.to(device) -optimizer_G = torch.optim.Adam(net.parameters(), lr = 0.0002) -optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = 0.0004) +optimizer_G = torch.optim.Adam(net.parameters(), lr=0.0002) +optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0004) # ### Training loop # @@ -239,92 +243,96 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat net = net.to(device) discriminator = discriminator.to(device) torch.autograd.set_detect_anomaly(True) -losses = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []} -losses_val = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []} +losses = {"kld": [], "perceptual": [], "feature": [], "generator": [], "discriminator": []} +losses_val = {"kld": [], "perceptual": [], "feature": [], "generator": [], 
"discriminator": []} for epoch in range(num_epochs): - print("Epoch %d/%d" %(epoch, num_epochs)) + print("Epoch %d/%d" % (epoch, num_epochs)) train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120) - losses_epoch = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0} + losses_epoch = {"kld": 0, "perceptual": 0, "feature": 0, "generator": 0, "discriminator": 0} for step, d in train_bar: - image = d['image'].to(device) + image = d["image"].to(device) with torch.no_grad(): - label = one_hot(d['label'], 6).to(device) + label = one_hot(d["label"], 6).to(device) optimizer_G.zero_grad() # Losses gen out, kld_loss = net(label, image) disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1)) - loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False) + loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False) disc_reals, features_reals = discriminator(torch.cat([image, label], 1)) loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device) - loss_perc = perceptual_loss(out, target = image) - total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat - total_loss.backward(retain_graph = True) + loss_perc = perceptual_loss(out, target=image) + total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat + total_loss.backward(retain_graph=True) optimizer_G.step() # Store - losses_epoch['kld'] += kld_loss.item() - losses_epoch['perceptual'] += loss_perc.item() - losses_epoch['generator'] += loss_g.item() - #Train disc + losses_epoch["kld"] += kld_loss.item() + losses_epoch["perceptual"] += loss_perc.item() + losses_epoch["generator"] += loss_g.item() + # Train disc out, _ = net(label, image) disc_fakes, _ = discriminator(torch.cat([out, label], 1)) - loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True) - loss_g_f = 
adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True) + loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True) + loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True) optimizer_D.zero_grad() loss_d = loss_d_r + loss_g_f loss_d.backward() optimizer_D.step() # Store - losses_epoch['feature'] = loss_feat.item() - losses_epoch['discriminator'] = loss_d_r.item() + loss_g_f.item() + losses_epoch["feature"] = loss_feat.item() + losses_epoch["discriminator"] = loss_d_r.item() + loss_g_f.item() train_bar.set_postfix( - {"kld": kld_loss.item(), - "perceptual": loss_perc.item(), - "generator": loss_g.item(), - "feature": loss_feat.item(), - "discriminator": loss_d_r.item() + loss_g_f.item(), - }) + { + "kld": kld_loss.item(), + "perceptual": loss_perc.item(), + "generator": loss_g.item(), + "feature": loss_feat.item(), + "discriminator": loss_d_r.item() + loss_g_f.item(), + } + ) val_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=120) - losses_epoch_val = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0} + losses_epoch_val = {"kld": 0, "perceptual": 0, "feature": 0, "generator": 0, "discriminator": 0} for step, d in val_bar: - image = d['image'].to(device) + image = d["image"].to(device) with torch.no_grad(): - label = one_hot(d['label'], 6).to(device) + label = one_hot(d["label"], 6).to(device) # Losses gen out, kld_loss = net(label, image) disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1)) - loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False) + loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False) disc_reals, features_reals = discriminator(torch.cat([image, label], 1)) loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device) - loss_perc = perceptual_loss(out, target = image) - total_loss = loss_adv * loss_g + loss_perc * lambda_perc + 
kld_loss * lambda_kld + loss_feat * lambda_feat + loss_perc = perceptual_loss(out, target=image) + total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat # Store - losses_epoch_val['kld'] += kld_loss.item() - losses_epoch_val['perceptual'] += loss_perc.item() - losses_epoch_val['generator'] += loss_g.item() - #Train disc + losses_epoch_val["kld"] += kld_loss.item() + losses_epoch_val["perceptual"] += loss_perc.item() + losses_epoch_val["generator"] += loss_g.item() + # Train disc out, _ = net(label, image) disc_fakes, _ = discriminator(torch.cat([out, label], 1)) - loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True) - loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True) + loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True) + loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True) loss_d = loss_adv * (loss_d_r + loss_g_f) # Store - losses_epoch_val['feature'] = loss_feat.item() - losses_epoch_val['discriminator'] = loss_d_r.item() + loss_g_f.item() + losses_epoch_val["feature"] = loss_feat.item() + losses_epoch_val["discriminator"] = loss_d_r.item() + loss_g_f.item() val_bar.set_postfix( - {"kld": kld_loss.item(), - "perceptual": loss_perc.item(), - "generator": loss_g.item(), - "feature": loss_feat.item(), - "discriminator": loss_d_r.item() + loss_g_f.item(), - }) - if step == 0 and epoch%10==0: + { + "kld": kld_loss.item(), + "perceptual": loss_perc.item(), + "generator": loss_g.item(), + "feature": loss_feat.item(), + "discriminator": loss_d_r.item() + loss_g_f.item(), + } + ) + if step == 0 and epoch % 10 == 0: picture_results(label, image, out) for key, val in losses_epoch.items(): losses[key].append(val / len(train_loader)) @@ -333,16 +341,16 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat # Plot losses -colors = ['orangered', 'royalblue', 
'hotpink', 'lime', 'goldenrod'] -plt.figure(figsize=(5,10)) +colors = ["orangered", "royalblue", "hotpink", "lime", "goldenrod"] +plt.figure(figsize=(5, 10)) ind = 0 for key, val in losses.items(): - plt.subplot(len(losses.keys()),1,ind+1) - plt.plot(val, color = colors[ind], linestyle = '-') - plt.plot(losses_val[key], color = colors[ind], linestyle = '--') + plt.subplot(len(losses.keys()), 1, ind + 1) + plt.plot(val, color=colors[ind], linestyle="-") + plt.plot(losses_val[key], color=colors[ind], linestyle="--") plt.title(key) plt.xlabel("Epochs") - ind+=1; + ind += 1 plt.tight_layout() plt.show() @@ -350,5 +358,3 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat # **Conclusion**: from early on, the network shows the capability of discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid overfitting, seen in some loss plots of previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure that the discriminator and generator and the discriminator are trained only when their performance does not exceed a certain limit. 
# # - - - From 1b421954a9f08070c5d1aa3231abf85092ca8627 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 5 Jul 2023 10:25:08 -0600 Subject: [PATCH 7/7] Fix tutorial formatting --- .../2d_spade_gan/2d_spade_vae.ipynb | 277 +++++++++--------- 1 file changed, 139 insertions(+), 138 deletions(-) diff --git a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb index 9f5333b6..125ecd8b 100644 --- a/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb +++ b/tutorials/generative/2d_spade_gan/2d_spade_vae.ipynb @@ -79,7 +79,7 @@ "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", "root_dir = tempfile.mkdtemp() if directory is None else directory\n", "root_dir = Path(root_dir)\n", - "print(\"Temporary directory used: %s \" %root_dir)" + "print(\"Temporary directory used: %s \" % root_dir)" ] }, { @@ -160,8 +160,9 @@ } ], "source": [ - "gdown.download(\"https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\",\n", - " str(root_dir / 'data.zip'))" + "gdown.download(\n", + " \"https://drive.google.com/uc?export=download&id=1SX_MCzQe-vyq09QYxECk32wZ2vxp9rx5\", str(root_dir / \"data.zip\")\n", + ")" ] }, { @@ -171,7 +172,7 @@ "metadata": {}, "outputs": [], "source": [ - "zip_obj = zipfile.ZipFile(os.path.join(root_dir, 'data.zip'), 'r')\n", + "zip_obj = zipfile.ZipFile(os.path.join(root_dir, \"data.zip\"), \"r\")\n", "zip_obj.extractall(root_dir)\n", "images_T1 = root_dir / \"OASIS_SMALL-SUBSET/T1\"\n", "images_FLAIR = root_dir / \"OASIS_SMALL-SUBSET/FLAIR\"\n", @@ -186,14 +187,16 @@ "outputs": [], "source": [ "# We create the data dictionaries that we need\n", - "all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + \\\n", - " [os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)]\n", + "all_images = [os.path.join(images_T1, i) for i in os.listdir(images_T1)] + [\n", + " os.path.join(images_FLAIR, i) for i in os.listdir(images_FLAIR)\n", + "]\n", 
"np.random.shuffle(all_images)\n", - "corresponding_labels = [os.path.join(labels, i.split(\"/\")[-1].replace(i.split(\"/\")[-1].split(\"_\")[0], \"Parcellation\"))\n", - " for i in all_images]\n", - "input_dict = [{'image': i, 'label': corresponding_labels[ind]} for ind, i in enumerate(all_images)]\n", - "input_dict_train = input_dict[:int(len(input_dict)*0.9)]\n", - "input_dict_val = input_dict[int(len(input_dict)*0.9):]" + "corresponding_labels = [\n", + " os.path.join(labels, i.split(\"/\")[-1].replace(i.split(\"/\")[-1].split(\"_\")[0], \"Parcellation\")) for i in all_images\n", + "]\n", + "input_dict = [{\"image\": i, \"label\": corresponding_labels[ind]} for ind, i in enumerate(all_images)]\n", + "input_dict_train = input_dict[: int(len(input_dict) * 0.9)]\n", + "input_dict_val = input_dict[int(len(input_dict) * 0.9) :]" ] }, { @@ -213,36 +216,38 @@ }, "outputs": [], "source": [ - "preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain\n", + "preliminar_shape = input_shape + [50] # We take random slices fron the center of the brain\n", "crop_shape = input_shape + [1]\n", "base_transforms = [\n", - " monai.transforms.LoadImaged(keys = ['label', 'image']),\n", - " monai.transforms.EnsureChannelFirstd(keys=['image', 'label']),\n", - " monai.transforms.CenterSpatialCropd(keys=['label', 'image'],\n", - " roi_size=preliminar_shape),\n", - " monai.transforms.RandSpatialCropd(keys = ['label', 'image'],\n", - " roi_size=crop_shape, max_roi_size=crop_shape),\n", - " monai.transforms.SqueezeDimd(keys=['label', 'image'], dim = -1),\n", - " monai.transforms.Resized(keys = ['image', 'label'], spatial_size=input_shape),\n", + " monai.transforms.LoadImaged(keys=[\"label\", \"image\"]),\n", + " monai.transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", + " monai.transforms.CenterSpatialCropd(keys=[\"label\", \"image\"], roi_size=preliminar_shape),\n", + " monai.transforms.RandSpatialCropd(keys=[\"label\", \"image\"], 
roi_size=crop_shape, max_roi_size=crop_shape),\n", + " monai.transforms.SqueezeDimd(keys=[\"label\", \"image\"], dim=-1),\n", + " monai.transforms.Resized(keys=[\"image\", \"label\"], spatial_size=input_shape),\n", "]\n", "last_transforms = [\n", - " monai.transforms.CopyItemsd(keys=['label'], names=['label_channel']),\n", - " monai.transforms.Lambdad(keys=['label_channel'],\n", - " func=lambda l: l != 0),\n", - " monai.transforms.MaskIntensityd(keys=['image'], mask_key='label_channel'),\n", - " monai.transforms.NormalizeIntensityd(keys=['image']),\n", - " monai.transforms.ToTensord(keys=['image', 'label'])\n", - " ]\n", + " monai.transforms.CopyItemsd(keys=[\"label\"], names=[\"label_channel\"]),\n", + " monai.transforms.Lambdad(keys=[\"label_channel\"], func=lambda l: l != 0),\n", + " monai.transforms.MaskIntensityd(keys=[\"image\"], mask_key=\"label_channel\"),\n", + " monai.transforms.NormalizeIntensityd(keys=[\"image\"]),\n", + " monai.transforms.ToTensord(keys=[\"image\", \"label\"]),\n", + "]\n", "\n", "aug_transforms = [\n", - " monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=['image']),\n", - " monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=['image']),\n", - " monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015),\n", - " keys=['image']),\n", - " monai.transforms.RandAffined(rotate_range=[-0.05, 0.05], shear_range=[0.001, 0.05],\n", - " scale_range=[0, 0.05], padding_mode='zeros',\n", - " mode='nearest', prob=0.33, keys=['label', 'image'])\n", - " ]\n", + " monai.transforms.RandBiasFieldd(coeff_range=(0, 0.005), prob=0.33, keys=[\"image\"]),\n", + " monai.transforms.RandAdjustContrastd(gamma=(0.9, 1.15), prob=0.33, keys=[\"image\"]),\n", + " monai.transforms.RandGaussianNoised(prob=0.33, mean=0.0, std=np.random.uniform(0.005, 0.015), keys=[\"image\"]),\n", + " monai.transforms.RandAffined(\n", + " rotate_range=[-0.05, 0.05],\n", + " shear_range=[0.001, 0.05],\n", + " 
scale_range=[0, 0.05],\n", + " padding_mode=\"zeros\",\n", + " mode=\"nearest\",\n", + " prob=0.33,\n", + " keys=[\"label\", \"image\"],\n", + " ),\n", + "]\n", "\n", "train_transforms = monai.transforms.Compose(base_transforms + aug_transforms + last_transforms)\n", "val_transforms = monai.transforms.Compose(base_transforms + last_transforms)\n", @@ -280,11 +285,13 @@ "source": [ "# Sanity check\n", "batch = next(iter(train_loader))\n", - "print(batch['image'].shape)\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(batch['image'][0,0,...], cmap = 'gist_gray'); plt.axis('off')\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(batch['label'][0,0,...], cmap = \"jet\"); plt.axis('off')\n", + "print(batch[\"image\"].shape)\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(batch[\"image\"][0, 0, ...], cmap=\"gist_gray\")\n", + "plt.axis(\"off\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(batch[\"label\"][0, 0, ...], cmap=\"jet\")\n", + "plt.axis(\"off\")\n", "plt.show()" ] }, @@ -323,14 +330,6 @@ " return label_out" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "6af2779b", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 11, @@ -339,18 +338,18 @@ "outputs": [], "source": [ "def picture_results(input_label, input_image, output_image):\n", - " f = plt.figure(figsize = (4, 1.5))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(torch.argmax(input_label, 1)[0,...].detach().cpu(), cmap = 'jet')\n", - " plt.axis('off')\n", + " f = plt.figure(figsize=(4, 1.5))\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(torch.argmax(input_label, 1)[0, ...].detach().cpu(), cmap=\"jet\")\n", + " plt.axis(\"off\")\n", " plt.title(\"Label\")\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(input_image[0,0,...].detach().cpu(), cmap = 'gist_gray')\n", - " plt.axis('off')\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(input_image[0, 0, ...].detach().cpu(), cmap=\"gist_gray\")\n", + " plt.axis(\"off\")\n", " plt.title(\"Input image\")\n", - " 
plt.subplot(1,3,3)\n", - " plt.imshow(output_image[0,0,...].detach().cpu(), cmap = 'gist_gray')\n", - " plt.axis('off')\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(output_image[0, 0, ...].detach().cpu(), cmap=\"gist_gray\")\n", + " plt.axis(\"off\")\n", " plt.title(\"Output image\")\n", " plt.show()" ] @@ -369,8 +368,7 @@ " for i in range(num_D): # for each discriminator\n", " num_intermediate_outputs = len(input_features_disc_fake[i])\n", " for j in range(num_intermediate_outputs): # for each layer output\n", - " unweighted_loss = criterion(input_features_disc_fake[i][j],\n", - " input_features_disc_real[i][j].detach())\n", + " unweighted_loss = criterion(input_features_disc_fake[i][j], input_features_disc_real[i][j].detach())\n", " GAN_Feat_loss += unweighted_loss * lambda_feat / num_D\n", " return GAN_Feat_loss" ] @@ -382,14 +380,16 @@ "metadata": {}, "outputs": [], "source": [ - "net = SPADE_Net(spatial_dims = 2,\n", - " in_channels = 1,\n", - " out_channels = 1,\n", - " label_nc = 6,\n", - " input_shape = input_shape,\n", - " num_channels = [16, 32, 64, 128],\n", - " z_dim = 16,\n", - " is_vae = True)" + "net = SPADE_Net(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " label_nc=6,\n", + " input_shape=input_shape,\n", + " num_channels=[16, 32, 64, 128],\n", + " z_dim=16,\n", + " is_vae=True,\n", + ")" ] }, { @@ -399,18 +399,19 @@ "metadata": {}, "outputs": [], "source": [ - "discriminator = MultiScalePatchDiscriminator(num_d = 2,\n", - " num_layers_d = 3,\n", - " spatial_dims = 2,\n", - " num_channels = 8,\n", - " in_channels = 7,\n", - " out_channels = 7,\n", - " minimum_size_im = 128,\n", - " norm = \"INSTANCE\",\n", - " kernel_size = 3\n", - " )\n", + "discriminator = MultiScalePatchDiscriminator(\n", + " num_d=2,\n", + " num_layers_d=3,\n", + " spatial_dims=2,\n", + " num_channels=8,\n", + " in_channels=7,\n", + " out_channels=7,\n", + " minimum_size_im=128,\n", + " norm=\"INSTANCE\",\n", + " kernel_size=3,\n", + ")\n", 
"\n", - "adversarial_loss = PatchAdversarialLoss(reduction = \"sum\", criterion = \"hinge\")" + "adversarial_loss = PatchAdversarialLoss(reduction=\"sum\", criterion=\"hinge\")" ] }, { @@ -429,11 +430,8 @@ } ], "source": [ - "perceptual_loss = PerceptualLoss(spatial_dims = 2,\n", - " network_type = \"vgg\",\n", - " is_fake_3d = False,\n", - " pretrained = True)\n", - "perceptual_loss=perceptual_loss.to(device)" + "perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"vgg\", is_fake_3d=False, pretrained=True)\n", + "perceptual_loss = perceptual_loss.to(device)" ] }, { @@ -443,8 +441,8 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer_G = torch.optim.Adam(net.parameters(), lr = 0.0002)\n", - "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = 0.0004)" + "optimizer_G = torch.optim.Adam(net.parameters(), lr=0.0002)\n", + "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0004)" ] }, { @@ -2237,92 +2235,96 @@ "net = net.to(device)\n", "discriminator = discriminator.to(device)\n", "torch.autograd.set_detect_anomaly(True)\n", - "losses = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []}\n", - "losses_val = {'kld': [], 'perceptual': [], 'feature': [], 'generator': [], 'discriminator': []}\n", + "losses = {\"kld\": [], \"perceptual\": [], \"feature\": [], \"generator\": [], \"discriminator\": []}\n", + "losses_val = {\"kld\": [], \"perceptual\": [], \"feature\": [], \"generator\": [], \"discriminator\": []}\n", "for epoch in range(num_epochs):\n", - " print(\"Epoch %d/%d\" %(epoch, num_epochs))\n", + " print(\"Epoch %d/%d\" % (epoch, num_epochs))\n", " train_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=120)\n", - " losses_epoch = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0}\n", + " losses_epoch = {\"kld\": 0, \"perceptual\": 0, \"feature\": 0, \"generator\": 0, \"discriminator\": 0}\n", " for step, d in train_bar:\n", - " image = 
d['image'].to(device)\n", + " image = d[\"image\"].to(device)\n", " with torch.no_grad():\n", - " label = one_hot(d['label'], 6).to(device)\n", + " label = one_hot(d[\"label\"], 6).to(device)\n", " optimizer_G.zero_grad()\n", "\n", " # Losses gen\n", " out, kld_loss = net(label, image)\n", " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", - " loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False)\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)\n", " disc_reals, features_reals = discriminator(torch.cat([image, label], 1))\n", " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", - " loss_perc = perceptual_loss(out, target = image)\n", - " total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", - " total_loss.backward(retain_graph = True)\n", + " loss_perc = perceptual_loss(out, target=image)\n", + " total_loss = loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " total_loss.backward(retain_graph=True)\n", " optimizer_G.step()\n", "\n", " # Store\n", - " losses_epoch['kld'] += kld_loss.item()\n", - " losses_epoch['perceptual'] += loss_perc.item()\n", - " losses_epoch['generator'] += loss_g.item()\n", - " #Train disc\n", + " losses_epoch[\"kld\"] += kld_loss.item()\n", + " losses_epoch[\"perceptual\"] += loss_perc.item()\n", + " losses_epoch[\"generator\"] += loss_g.item()\n", + " # Train disc\n", " out, _ = net(label, image)\n", " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", - " loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True)\n", - " loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True)\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, 
for_discriminator=True)\n", " optimizer_D.zero_grad()\n", " loss_d = loss_d_r + loss_g_f\n", " loss_d.backward()\n", " optimizer_D.step()\n", "\n", " # Store\n", - " losses_epoch['feature'] = loss_feat.item()\n", - " losses_epoch['discriminator'] = loss_d_r.item() + loss_g_f.item()\n", + " losses_epoch[\"feature\"] = loss_feat.item()\n", + " losses_epoch[\"discriminator\"] = loss_d_r.item() + loss_g_f.item()\n", "\n", " train_bar.set_postfix(\n", - " {\"kld\": kld_loss.item(),\n", - " \"perceptual\": loss_perc.item(),\n", - " \"generator\": loss_g.item(),\n", - " \"feature\": loss_feat.item(),\n", - " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", - " })\n", + " {\n", + " \"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", + " }\n", + " )\n", "\n", " val_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=120)\n", - " losses_epoch_val = {'kld': 0, 'perceptual': 0, 'feature': 0, 'generator': 0, 'discriminator': 0}\n", + " losses_epoch_val = {\"kld\": 0, \"perceptual\": 0, \"feature\": 0, \"generator\": 0, \"discriminator\": 0}\n", " for step, d in val_bar:\n", - " image = d['image'].to(device)\n", + " image = d[\"image\"].to(device)\n", " with torch.no_grad():\n", - " label = one_hot(d['label'], 6).to(device)\n", + " label = one_hot(d[\"label\"], 6).to(device)\n", " # Losses gen\n", " out, kld_loss = net(label, image)\n", " disc_fakes, features_fakes = discriminator(torch.cat([out, label], 1))\n", - " loss_g = adversarial_loss(disc_fakes, target_is_real = True, for_discriminator = False)\n", + " loss_g = adversarial_loss(disc_fakes, target_is_real=True, for_discriminator=False)\n", " disc_reals, features_reals = discriminator(torch.cat([image, label], 1))\n", " loss_feat = feature_loss(features_fakes, features_reals, lambda_feat, device)\n", - " loss_perc = perceptual_loss(out, target = 
image)\n", - " total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", + " loss_perc = perceptual_loss(out, target=image)\n", + " total_loss = loss_adv * loss_g + loss_perc * lambda_perc + kld_loss * lambda_kld + loss_feat * lambda_feat\n", " # Store\n", - " losses_epoch_val['kld'] += kld_loss.item()\n", - " losses_epoch_val['perceptual'] += loss_perc.item()\n", - " losses_epoch_val['generator'] += loss_g.item()\n", - " #Train disc\n", + " losses_epoch_val[\"kld\"] += kld_loss.item()\n", + " losses_epoch_val[\"perceptual\"] += loss_perc.item()\n", + " losses_epoch_val[\"generator\"] += loss_g.item()\n", + " # Train disc\n", " out, _ = net(label, image)\n", " disc_fakes, _ = discriminator(torch.cat([out, label], 1))\n", - " loss_d_r = adversarial_loss(disc_reals, target_is_real = True, for_discriminator = True)\n", - " loss_g_f = adversarial_loss(disc_fakes, target_is_real = False, for_discriminator = True)\n", + " loss_d_r = adversarial_loss(disc_reals, target_is_real=True, for_discriminator=True)\n", + " loss_g_f = adversarial_loss(disc_fakes, target_is_real=False, for_discriminator=True)\n", " loss_d = loss_adv * (loss_d_r + loss_g_f)\n", "\n", " # Store\n", - " losses_epoch_val['feature'] = loss_feat.item()\n", - " losses_epoch_val['discriminator'] = loss_d_r.item() + loss_g_f.item()\n", + " losses_epoch_val[\"feature\"] = loss_feat.item()\n", + " losses_epoch_val[\"discriminator\"] = loss_d_r.item() + loss_g_f.item()\n", "\n", " val_bar.set_postfix(\n", - " {\"kld\": kld_loss.item(),\n", - " \"perceptual\": loss_perc.item(),\n", - " \"generator\": loss_g.item(),\n", - " \"feature\": loss_feat.item(),\n", - " \"discriminator\": loss_d_r.item() + loss_g_f.item(),\n", - " })\n", - " if step == 0 and epoch%10==0:\n", + " {\n", + " \"kld\": kld_loss.item(),\n", + " \"perceptual\": loss_perc.item(),\n", + " \"generator\": loss_g.item(),\n", + " \"feature\": loss_feat.item(),\n", + " \"discriminator\": 
loss_d_r.item() + loss_g_f.item(),\n", + " }\n", + " )\n", + " if step == 0 and epoch % 10 == 0:\n", " picture_results(label, image, out)\n", " for key, val in losses_epoch.items():\n", " losses[key].append(val / len(train_loader))\n", @@ -2349,16 +2351,16 @@ ], "source": [ "# Plot losses\n", - "colors = ['orangered', 'royalblue', 'hotpink', 'lime', 'goldenrod']\n", - "plt.figure(figsize=(5,10))\n", + "colors = [\"orangered\", \"royalblue\", \"hotpink\", \"lime\", \"goldenrod\"]\n", + "plt.figure(figsize=(5, 10))\n", "ind = 0\n", "for key, val in losses.items():\n", - " plt.subplot(len(losses.keys()),1,ind+1)\n", - " plt.plot(val, color = colors[ind], linestyle = '-')\n", - " plt.plot(losses_val[key], color = colors[ind], linestyle = '--')\n", + " plt.subplot(len(losses.keys()), 1, ind + 1)\n", + " plt.plot(val, color=colors[ind], linestyle=\"-\")\n", + " plt.plot(losses_val[key], color=colors[ind], linestyle=\"--\")\n", " plt.title(key)\n", " plt.xlabel(\"Epochs\")\n", - " ind+=1;\n", + " ind += 1\n", "plt.tight_layout()\n", "plt.show()" ] @@ -2367,7 +2369,6 @@ "cell_type": "markdown", "id": "c3cf096f", "metadata": { - "lines_to_next_cell": 0, "pycharm": { "name": "#%%" }