diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 682c1e76..893ec1e8 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -16,10 +16,6 @@
 from lpips import LPIPS
 
 
-# TODO: Define model_path for lpips networks.
-# TODO: Add MedicalNet for true 3D computation (https://github.com/Tencent/MedicalNet)
-# TODO: Add RadImageNet for 2D computation with networks pretrained using radiological images
-# (https://github.com/BMEII-AI/RadImageNet)
 class PerceptualLoss(nn.Module):
     """
     Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
@@ -30,7 +26,8 @@ class PerceptualLoss(nn.Module):
 
     Args:
         spatial_dims: number of spatial dimensions.
-        network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``}
+        network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"medicalnet_resnet10_23datasets"``,
+            ``"medicalnet_resnet50_23datasets"``}
             Specifies the network architecture to use. Defaults to ``"alex"``.
         is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
         fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
@@ -48,15 +45,15 @@ def __init__(
         if spatial_dims not in [2, 3]:
             raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
 
-        if spatial_dims == 3 and is_fake_3d is False:
-            raise NotImplementedError("True 3D perceptual loss is not implemented. Try setting is_fake_3d=False")
-
         self.spatial_dims = spatial_dims
-        self.perceptual_function = LPIPS(
-            pretrained=True,
-            net=network_type,
-            verbose=False,
-        )
+        if spatial_dims == 3 and is_fake_3d is False:
+            self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False)
+        else:
+            self.perceptual_function = LPIPS(
+                pretrained=True,
+                net=network_type,
+                verbose=False,
+            )
         self.is_fake_3d = is_fake_3d
         self.fake_3d_ratio = fake_3d_ratio
 
@@ -126,6 +123,78 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)
-        loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
-        loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)
-        loss = loss_sagittal + loss_axial + loss_coronal
+        if self.spatial_dims == 3 and self.is_fake_3d:
+            # 2.5D approach: sum the slice-based 2D perceptual losses over the three axes.
+            loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)
+            loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
+            loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)
+            loss = loss_sagittal + loss_axial + loss_coronal
+        else:
+            # 2D case, or true 3D case handled by MedicalNetPerceptualComponent.
+            loss = self.perceptual_function(input, target)
 
         return torch.mean(loss)
+
+
+class MedicalNetPerceptualComponent(nn.Module):
+    """
+    Component to perform perceptual evaluation with the networks pretrained by Chen et al. in "Med3D: Transfer
+    Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
+    "Warvito/MedicalNet-models".
+
+    Args:
+        net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
+            Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
+        verbose: if False, mute messages from the torch Hub load function.
+    """
+
+    def __init__(
+        self,
+        net: str = "medicalnet_resnet10_23datasets",
+        verbose: bool = False,
+    ) -> None:
+        super().__init__()
+        # Bypass torch Hub's forked-repo validation, which can fail on GitHub API rate limits.
+        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+        self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
+        self.eval()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the perceptual loss using MedicalNet 3D networks. The input and target tensors are fed into the
+        pretrained MedicalNet, which is used as a feature extractor. The extracted features are then normalised
+        across the channels. Finally, we compute the squared difference between the input and target features and
+        average it over the spatial dimensions to obtain the perceptual loss.
+
+        Args:
+            input: 3D input tensor with shape BCDHW.
+            target: 3D target tensor with shape BCDHW.
+        """
+        input = medicalnet_intensity_normalisation(input)
+        target = medicalnet_intensity_normalisation(target)
+
+        # Get model outputs
+        outs_input = self.model.forward(input)
+        outs_target = self.model.forward(target)
+
+        # Normalise through the channels
+        feats_input = normalize_tensor(outs_input)
+        feats_target = normalize_tensor(outs_target)
+
+        results = (feats_input - feats_target) ** 2
+        results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
+
+        return results
+
+
+def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
+    return x.mean([2, 3, 4], keepdim=keepdim)
+
+
+def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
+    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
+    """Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133"""
+    mean = volume.mean()
+    std = volume.std()
+    return (volume - mean) / std
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c8237cdd..6c023b84 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -13,3 +13,4 @@ matplotlib!=3.5.0
 einops
 tensorboard>=2.11.0
 nibabel>=4.0.2
+gdown>=4.4.0
diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index a1e8b098..9bb9ab23 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -28,6 +28,11 @@
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
+    [
+        {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
+        (2, 1, 64, 64, 64),
+        (2, 1, 64, 64, 64),
+    ],
 ]
 
 
@@ -52,10 +57,6 @@ def test_different_shape(self):
         with self.assertRaises(ValueError):
             loss(tensor, target)
 
-    def test_true_3d(self):
-        with self.assertRaises(NotImplementedError):
-            PerceptualLoss(spatial_dims=3, is_fake_3d=False)
-
     def test_1d(self):
         with self.assertRaises(NotImplementedError):
             PerceptualLoss(spatial_dims=1)
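
For reviewers, a minimal usage sketch of the two 3D code paths after this change (not part of the diff): the shapes and the `medicalnet_resnet10_23datasets` arguments mirror the test case added above, while the `"squeeze"` backbone and `fake_3d_ratio=0.5` are illustrative choices; the true-3D path assumes the MedicalNet weights are reachable via torch Hub.

```python
import torch

from generative.losses.perceptual import PerceptualLoss

# 2.5D ("fake 3D") path: volumes are sliced and scored with a 2D LPIPS backbone.
fake_3d_loss = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True, fake_3d_ratio=0.5)

# True 3D path introduced by this PR: routed through MedicalNetPerceptualComponent.
true_3d_loss = PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False)

input_volume = torch.randn(2, 1, 64, 64, 64)  # BCDHW, single-channel volumes
target_volume = torch.randn(2, 1, 64, 64, 64)

print(fake_3d_loss(input_volume, target_volume))  # scalar tensor
print(true_3d_loss(input_volume, target_volume))  # scalar tensor
```

Both paths reduce to a scalar via `torch.mean`, so the module can be dropped into existing training loops unchanged.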