diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 893ec1e8..30314467 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -19,15 +19,18 @@ class PerceptualLoss(nn.Module): """ Perceptual loss using features from pretrained deep neural networks trained. The function supports networks - pretrained on ImageNet that use the LPIPS approach from: Zhang, et al. "The unreasonable effectiveness of deep - features as a perceptual metric." https://arxiv.org/abs/1801.03924 + pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep + features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An + Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al. + "Med3D: Transfer Learning for 3D Medical Image Analysis" . + The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the three axis. Args: spatial_dims: number of spatial dimensions. - network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"medicalnet_resnet10_23datasets"``, - ``"medicalnet_resnet50_23datasets"``} + network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``, + ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} Specifies the network architecture to use. Defaults to ``"alex"``. is_fake_3d: if True use 2.5D approach for a 3D perceptual loss. fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach. @@ -48,6 +51,8 @@ def __init__( self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False) + elif "radimagenet_" in network_type: + self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False) else: self.perceptual_function = LPIPS( pretrained=True, @@ -116,15 +121,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - if self.spatial_dims == 2: - loss = self.perceptual_function(input, target) - elif self.spatial_dims == 3 and self.is_fake_3d: + if self.spatial_dims == 3 and self.is_fake_3d: # Compute 2.5D approach loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2) loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3) loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4) loss = loss_sagittal + loss_axial + loss_coronal - if self.spatial_dims == 3 and self.is_fake_3d is False: + else: + # 2D and real 3D cases loss = self.perceptual_function(input, target) return torch.mean(loss) @@ -194,3 +198,70 @@ def medicalnet_intensity_normalisation(volume): mean = volume.mean() std = volume.std() return (volume - mean) / std + + +class RadImageNetPerceptualComponent(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et + al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class + uses torch Hub to download the networks from "Warvito/radimagenet-models". + + Args: + net: {``"radimagenet_resnet50"``} + Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__( + self, + net: str = "radimagenet_resnet50", + verbose: bool = False, + ) -> None: + super().__init__() + self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) + self.eval() + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from + 'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised + across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). + """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Change order from 'RGB' to 'BGR' + input = input[:, [2, 1, 0], ...] + target = target[:, [2, 1, 0], ...] + + # Subtract mean used during training + input = subtract_mean(input) + target = subtract_mean(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3], keepdim=keepdim) + + +def subtract_mean(x: torch.Tensor) -> torch.Tensor: + mean = [0.406, 0.456, 0.485] + x[:, 0, :, :] -= mean[0] + x[:, 1, :, :] -= mean[1] + x[:, 2, :, :] -= mean[2] + return x diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 9bb9ab23..bd1521ea 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -28,6 +28,21 @@ (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, + (2, 1, 64, 64), + (2, 1, 64, 64), + ], + [ + {"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, + (2, 3, 64, 64), + (2, 3, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], [ {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, (2, 1, 64, 64, 64),