From d823b2e16a91d85728176cfcffad5714d50753b5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Tue, 27 Dec 2022 16:35:11 +0000 Subject: [PATCH 1/6] [WIP] Add Perceptual loss based on MedicalNet 3D Networks Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 85 ++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 682c1e76..be071567 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -16,10 +16,6 @@ from lpips import LPIPS -# TODO: Define model_path for lpips networks. -# TODO: Add MedicalNet for true 3D computation (https://github.com/Tencent/MedicalNet) -# TODO: Add RadImageNet for 2D computation with networks pretrained using radiological images -# (https://github.com/BMEII-AI/RadImageNet) class PerceptualLoss(nn.Module): """ Perceptual loss using features from pretrained deep neural networks trained. The function supports networks @@ -30,7 +26,8 @@ class PerceptualLoss(nn.Module): Args: spatial_dims: number of spatial dimensions. - network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``} + network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"medicalnet_resnet10_23datasets"``, + ``"medicalnet_resnet50_23datasets"``} Specifies the network architecture to use. Defaults to ``"alex"``. is_fake_3d: if True use 2.5D approach for a 3D perceptual loss. fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach. @@ -48,15 +45,15 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if spatial_dims == 3 and is_fake_3d is False: - raise NotImplementedError("True 3D perceptual loss is not implemented. 
Try setting is_fake_3d=False") - self.spatial_dims = spatial_dims - self.perceptual_function = LPIPS( - pretrained=True, - net=network_type, - verbose=False, - ) + if spatial_dims == 3 and is_fake_3d is False: + self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False) + else: + self.perceptual_function = LPIPS( + pretrained=True, + net=network_type, + verbose=False, + ) self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio @@ -127,5 +124,67 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3) loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4) loss = loss_sagittal + loss_axial + loss_coronal + if self.spatial_dims == 3 and self.is_fake_3d is False: + loss = self.perceptual_function(input, target) return torch.mean(loss) + + +class MedicalNetPerceptualComponent(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer + Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from + "Warvito/MedicalNet-models". + + Args: + net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} + Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__( + self, + net: str = "medicalnet_resnet10_23datasets", + verbose: bool = False, + ) -> None: + super().__init__() + self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) + self.eval() + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute perceptual loss using MedicalNet 3D networks. + The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar + approach to the lpips package). 
+ """ + input = medicalnet_intensty_normalisation(input) + target = medicalnet_intensty_normalisation(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3, 4], keepdim=keepdim) + + +def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def medicalnet_intensty_normalisation(volume): + """Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133""" + mean = volume.mean() + std = volume.std() + return (volume - mean) / std From e60e53fe51a088d82821eb3d7c67057fc41a780d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Tue, 27 Dec 2022 16:38:21 +0000 Subject: [PATCH 2/6] Add tests Signed-off-by: Walter Hugo Lopez Pinaya --- tests/test_perceptual_loss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index a1e8b098..9bb9ab23 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -28,6 +28,11 @@ (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], ] @@ -52,10 +57,6 @@ def test_different_shape(self): with self.assertRaises(ValueError): loss(tensor, target) - def test_true_3d(self): - with self.assertRaises(NotImplementedError): - PerceptualLoss(spatial_dims=3, is_fake_3d=False) - def 
test_1d(self): with self.assertRaises(NotImplementedError): PerceptualLoss(spatial_dims=1) From 24861917b6d4baa40c10b260c72cbe83b9e55e38 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 4 Jan 2023 14:36:11 +0000 Subject: [PATCH 3/6] Add gdown dependency [#158] Signed-off-by: Walter Hugo Lopez Pinaya --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index c8237cdd..6c023b84 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,3 +13,4 @@ matplotlib!=3.5.0 einops tensorboard>=2.11.0 nibabel>=4.0.2 +gdown>=4.4.0 From 9e3b1388e0de2c9a6545be89bba8777a457d51d9 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 4 Jan 2023 14:42:38 +0000 Subject: [PATCH 4/6] Add _validate_not_a_forked_repo line [#158] Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 1 + 1 file changed, 1 insertion(+) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index be071567..74fb3af2 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -148,6 +148,7 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) self.eval() From 1a473fb8004f154ca3a8b4f5132ef492d7206cff Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 4 Jan 2023 14:43:25 +0000 Subject: [PATCH 5/6] Fix typo [#158] Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 74fb3af2..51086ba8 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -158,8 +158,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: The outputs are normalised across 
the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). """ - input = medicalnet_intensty_normalisation(input) - target = medicalnet_intensty_normalisation(target) + input = medicalnet_intensity_normalisation(input) + target = medicalnet_intensity_normalisation(target) # Get model outputs outs_input = self.model.forward(input) @@ -184,7 +184,7 @@ def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: return x / (norm_factor + eps) -def medicalnet_intensty_normalisation(volume): +def medicalnet_intensity_normalisation(volume): """Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133""" mean = volume.mean() std = volume.std() From cb2ce6f22dd98a890b958855407a219b150cb7ef Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 4 Jan 2023 16:20:43 +0000 Subject: [PATCH 6/6] Update docstring [#158] Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 51086ba8..893ec1e8 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -154,9 +154,14 @@ def __init__( def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Compute perceptual loss using MedicalNet 3D networks. - The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar - approach to the lpips package). + Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are fed into the + pre-trained MedicalNet that is used for feature extraction. Then, these extracted features are normalised across + the channels. Finally, we compute the difference between the input and target features and calculate the mean + value from the spatial dimensions to obtain the perceptual loss.
+ + Args: + input: 3D input tensor with shape BCDHW. + target: 3D target tensor with shape BCDHW. """ input = medicalnet_intensity_normalisation(input) target = medicalnet_intensity_normalisation(target)