From 88b7c3783420a27d4df8b5270e6dfbb81c0f1434 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 10 May 2024 14:35:27 +0100 Subject: [PATCH] Tidy up init Signed-off-by: Mark Graham --- monai/networks/nets/patchgan_discriminator.py | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py index 3b089616ce..74da917694 100644 --- a/monai/networks/nets/patchgan_discriminator.py +++ b/monai/networks/nets/patchgan_discriminator.py @@ -18,6 +18,7 @@ from monai.networks.blocks import Convolution from monai.networks.layers import Act +from monai.networks.utils import normal_init class MultiScalePatchDiscriminator(nn.Sequential): @@ -211,7 +212,7 @@ def __init__( ), ) - self.apply(self.initialise_weights) + self.apply(normal_init) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ @@ -227,21 +228,3 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: out.append(intermediate_output) return out[1:] - - def initialise_weights(self, m: nn.Module) -> None: - """ - Initialise weights of Convolution and BatchNorm layers. - - Args: - m: instance of torch.nn.module (or of class inheriting torch.nn.module) - """ - classname = m.__class__.__name__ - if classname.find("Conv2d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv3d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv1d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0)