From 5af4ecf923cdabead061f2f2747ff775462e919c Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 15:16:43 +0000 Subject: [PATCH 1/4] Adds patchgan discriminator Signed-off-by: Mark Graham --- docs/source/networks.rst | 8 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/patchgan_discriminator.py | 249 ++++++++++++++++++ tests/test_patch_gan_dicriminator.py | 179 +++++++++++++ 4 files changed, 437 insertions(+) create mode 100644 monai/networks/nets/patchgan_discriminator.py create mode 100644 tests/test_patch_gan_dicriminator.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 417fb8ac73..d0f74714f0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -751,6 +751,14 @@ Nets .. autoclass:: VQVAE :members: +`PatchGanDiscriminator` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchGanDiscriminator + :members: + +.. autoclass:: MultiScalePatchGanDiscriminator + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 31fbd73b4e..b6fbbb2173 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -54,6 +54,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py new file mode 100644 index 0000000000..c31212ec90 --- /dev/null +++ b/monai/networks/nets/patchgan_discriminator.py @@ -0,0 +1,249 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act + + +class MultiScalePatchDiscriminator(nn.Sequential): + """ + Multi-scale Patch-GAN discriminator based on Pix2PixHD: + + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs + Ting-Chun Wang + + The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images + at different spatial scales. + + Args: + num_d: number of discriminators + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first + discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. + spatial_dims: number of spatial dimensions (1D, 2D etc.) 
+ channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels in each discriminator + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + dropout: probability of dropout applied, defaults to 0. + minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture + requested isn't going to downsample the input image beyond value of 1. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + num_d: int, + num_layers_d: int, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + dropout: float | tuple = 0.0, + minimum_size_im: int = 256, + last_conv_kernel_size: int = 1, + ) -> None: + super().__init__() + self.num_d = num_d + self.num_layers_d = num_layers_d + self.num_channels = channels + self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) + for i_ in range(self.num_d): + num_layers_d_i = self.num_layers_d * (i_ + 1) + output_size = float(minimum_size_im) / (2**num_layers_d_i) + if output_size < 1: + raise AssertionError( + f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." + "Please reduce num_layers, reduce num_D or enter bigger images." + ) + subnet_d = PatchDiscriminator( + spatial_dims=spatial_dims, + channels=self.num_channels, + in_channels=in_channels, + out_channels=out_channels, + num_layers_d=num_layers_d_i, + kernel_size=kernel_size, + activation=activation, + norm=norm, + bias=bias, + padding=self.padding, + dropout=dropout, + last_conv_kernel_size=last_conv_kernel_size, + ) + + self.add_module("discriminator_%d" % i_, subnet_d) + + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: + """ + Args: + i: Input tensor + + Returns: + list of outputs and another list of lists with the intermediate features + of each discriminator. + """ + + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] + for disc in self.children(): + out_d: list[torch.Tensor] = disc(i) + out.append(out_d[-1]) + intermediate_features.append(out_d[:-1]) + + return out, intermediate_features + + +class PatchDiscriminator(nn.Sequential): + """ + Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs. + Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, Bryan Catanzaro + + Args: + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + padding: padding to be applied to the convolutional layers + dropout: proportion of dropout applied, defaults to 0. + last_conv_kernel_size: kernel size of the last convolutional layer. 
+ """ + + def __init__( + self, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + num_layers_d: int = 3, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, + ) -> None: + super().__init__() + self.num_layers_d = num_layers_d + self.num_channels = channels + if last_conv_kernel_size is None: + last_conv_kernel_size = kernel_size + + self.add_module( + "initial_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=in_channels, + out_channels=channels, + act=activation, + bias=True, + norm=None, + dropout=dropout, + padding=padding, + strides=2, + ), + ) + + input_channels = channels + output_channels = channels * 2 + + # Initial Layer + for l_ in range(self.num_layers_d): + if l_ == self.num_layers_d - 1: + stride = 1 + else: + stride = 2 + layer = Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + act=activation, + bias=bias, + norm=norm, + dropout=dropout, + padding=padding, + strides=stride, + ) + self.add_module("%d" % l_, layer) + input_channels = output_channels + output_channels = output_channels * 2 + + # Final layer + self.add_module( + "final_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=last_conv_kernel_size, + in_channels=input_channels, + out_channels=out_channels, + bias=True, + conv_only=True, + padding=int((last_conv_kernel_size - 1) / 2), + dropout=0.0, + strides=1, + ), + ) + + self.apply(self.initialise_weights) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + x: input tensor + + Returns: + list of intermediate features, with the last element being the output. + """ + out = [x] + for submodel in self.children(): + intermediate_output = submodel(out[-1]) + out.append(intermediate_output) + + return out[1:] + + def initialise_weights(self, m: nn.Module) -> None: + """ + Initialise weights of Convolution and BatchNorm layers. + + Args: + m: instance of torch.nn.module (or of class inheriting torch.nn.module) + """ + classname = m.__class__.__name__ + if classname.find("Conv2d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv3d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv1d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py new file mode 100644 index 0000000000..c19898e70d --- /dev/null +++ b/tests/test_patch_gan_dicriminator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator +from tests.utils import test_script_save + +TEST_PATCHGAN = [ + [ + { + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512]), + (1, 8, 128, 256), + (1, 1, 32, 64), + ], + [ + { + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512, 256]), + (1, 8, 128, 256, 128), + (1, 1, 32, 64, 32), + ], +] + +TEST_MULTISCALE_PATCHGAN = [ + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], + ], + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], + ], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(TEST_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): + net = PatchDiscriminator(**input_param) + with eval_mode(net): + result = net.forward(input_data) + self.assertEqual(tuple(result[0].shape), expected_shape_feature) + self.assertEqual(tuple(result[-1].shape), expected_shape_output) + + def test_script(self): + net = PatchDiscriminator( + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +class TestMultiscalePatchGAN(unittest.TestCase): + @parameterized.expand(TEST_MULTISCALE_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = 
torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() From 31414ca0c61c8be55a8967c77f745c0ad460855a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 15:20:45 +0000 Subject: [PATCH 2/4] Fixes docs Signed-off-by: Mark Graham --- docs/source/networks.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d0f74714f0..520bd8d3e4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -751,12 +751,12 @@ Nets .. autoclass:: VQVAE :members: -`PatchGanDiscriminator` +`PatchGANDiscriminator` ~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: PatchGanDiscriminator +.. autoclass:: PatchDiscriminator :members: -.. autoclass:: MultiScalePatchGanDiscriminator +.. autoclass:: MultiScalePatchDiscriminator :members: Utilities From 8eebef7b663dceaf559c91387a11b16083da9868 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 14 Dec 2023 03:51:39 -0600 Subject: [PATCH 3/4] Update monai/networks/nets/patchgan_discriminator.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/nets/patchgan_discriminator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py index c31212ec90..b4d49eee50 100644 --- a/monai/networks/nets/patchgan_discriminator.py +++ b/monai/networks/nets/patchgan_discriminator.py @@ -127,9 +127,9 @@ class PatchDiscriminator(nn.Sequential): out_channels: number of output channels num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. kernel_size: kernel size of the convolution layers - activation: activation layer type - norm: normalisation type - bias: introduction of layer bias + act: activation type and arguments. Defaults to LeakyReLU. + norm: feature normalization type and arguments. Defaults to batch norm. + bias: whether to have a bias term in convolution blocks. Defaults to False. padding: padding to be applied to the convolutional layers dropout: proportion of dropout applied, defaults to 0. last_conv_kernel_size: kernel size of the last convolutional layer. From 3a7d7759e83c4659ad5bd54e4d4b6b759c5a47d6 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 14 Dec 2023 09:56:29 +0000 Subject: [PATCH 4/4] Adds arxiv link to paper reference Signed-off-by: Mark Graham --- monai/networks/nets/patchgan_discriminator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py index b4d49eee50..3b089616ce 100644 --- a/monai/networks/nets/patchgan_discriminator.py +++ b/monai/networks/nets/patchgan_discriminator.py @@ -23,9 +23,7 @@ class MultiScalePatchDiscriminator(nn.Sequential): """ Multi-scale Patch-GAN discriminator based on Pix2PixHD: - - High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs - Ting-Chun Wang + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images at different spatial scales. @@ -117,8 +115,8 @@ def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch. 
class PatchDiscriminator(nn.Sequential): """ Patch-GAN discriminator based on Pix2PixHD: - High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs. - Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, Bryan Catanzaro + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + Args: spatial_dims: number of spatial dimensions (1D, 2D etc.)
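
---

Reviewer note (not part of the patches above): a minimal usage sketch of the two classes this PR adds, to make the return structure of forward() concrete. It assumes only what the patches introduce — the import path added to monai/networks/nets/__init__.py in PATCH 1/4, PatchDiscriminator returning a list of per-layer features whose last element is the prediction map, and MultiScalePatchDiscriminator returning (per-scale predictions, per-scale feature lists). The MSE adversarial term, the L1 feature-matching term, the 0.1 weighting, and the random tensor standing in for a generator output are illustrative assumptions, not part of this PR.

    import torch
    import torch.nn.functional as F

    from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator

    # Single-scale discriminator: forward returns the per-layer feature maps,
    # with the last element being the patch-wise prediction map.
    disc = PatchDiscriminator(spatial_dims=2, channels=8, in_channels=1, num_layers_d=3)
    img = torch.rand(2, 1, 256, 256)
    features = disc(img)
    prediction = features[-1]  # patch logits; earlier elements are intermediate features

    # Multi-scale discriminator: forward returns (per-scale predictions, per-scale features).
    msd = MultiScalePatchDiscriminator(
        num_d=2, num_layers_d=3, spatial_dims=2, channels=8, in_channels=1, minimum_size_im=256
    )
    real_outputs, real_features = msd(img)

    # Illustrative generator objective (assumed, not defined by this PR): an adversarial
    # term on each scale's prediction plus an L1 feature-matching term over the
    # intermediate features returned by each discriminator.
    fake = torch.rand_like(img)  # stand-in for a generator output
    fake_outputs, fake_features = msd(fake)
    adv_loss = sum(F.mse_loss(o, torch.ones_like(o)) for o in fake_outputs)
    fm_loss = sum(
        F.l1_loss(f, r.detach())
        for fake_feats, real_feats in zip(fake_features, real_features)
        for f, r in zip(fake_feats, real_feats)
    )
    generator_loss = adv_loss + 0.1 * fm_loss

Because each additional discriminator in MultiScalePatchDiscriminator is built with more downsampling layers, the per-scale predictions in real_outputs/fake_outputs have progressively smaller spatial size, matching the expected output shapes exercised in tests/test_patch_gan_dicriminator.py.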