diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 737bf076..c5ad80f0 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -14,6 +14,8 @@
 import torch
 import torch.nn as nn
 from lpips import LPIPS
+from torchvision.models import ResNet50_Weights, resnet50
+from torchvision.models.feature_extraction import create_feature_extractor
 
 
 class PerceptualLoss(nn.Module):
@@ -22,8 +24,9 @@ class PerceptualLoss(nn.Module):
     pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
     features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
     Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
-    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
-    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .
+    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
+    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
+    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .
 
     The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the
     three axis.
@@ -31,11 +34,19 @@
     Args:
         spatial_dims: number of spatial dimensions.
         network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
-            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
+            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
             Specifies the network architecture to use. Defaults to ``"alex"``.
         is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
         fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
         cache_dir: path to cache directory to save the pretrained network weights.
+        pretrained: whether to load pretrained weights. This argument only works when using networks from
+            LPIPS or Torchvision. Defaults to ``True``.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. This argument only works when ``"network_type"`` is "resnet50".
+            Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict from the loaded weights file. This argument only works when
+            ``"network_type"`` is "resnet50". Defaults to `None`.
""" def __init__( @@ -45,6 +56,9 @@ def __init__( is_fake_3d: bool = True, fake_3d_ratio: float = 0.5, cache_dir: str | None = None, + pretrained: bool = True, + pretrained_path: str | None = None, + pretrained_state_dict_key: str | None = None, ): super().__init__() @@ -65,8 +79,15 @@ def __init__( self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) + elif network_type == "resnet50": + self.perceptual_function = TorchvisionModelPerceptualSimilarity( + net=network_type, + pretrained=pretrained, + pretrained_path=pretrained_path, + pretrained_state_dict_key=pretrained_state_dict_key, + ) else: - self.perceptual_function = LPIPS(pretrained=True, net=network_type, verbose=False) + self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio @@ -247,10 +268,95 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return results +class TorchvisionModelPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with TorchVision models. + Currently, only ResNet50 is supported. The network structure is based on: + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html + + Args: + net: {``"resnet50"``} + Specifies the network architecture to use. Defaults to ``"resnet50"``. + pretrained: whether to load pretrained weights. Defaults to `True`. + pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded + via using this argument. Defaults to `None`. + pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to + extract the expected state dict. Defaults to `None`. + """ + + def __init__( + self, + net: str = "resnet50", + pretrained: bool = True, + pretrained_path: str | None = None, + pretrained_state_dict_key: str | None = None, + ) -> None: + super().__init__() + supported_networks = ["resnet50"] + if net not in supported_networks: + raise NotImplementedError( + f"'net' {net} is not supported, please select a network from {supported_networks}." + ) + + if pretrained_path is None: + network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None) + else: + network = resnet50(weights=None) + if pretrained is True: + state_dict = torch.load(pretrained_path) + if pretrained_state_dict_key is not None: + state_dict = state_dict[pretrained_state_dict_key] + network.load_state_dict(state_dict) + self.final_layer = "layer4.2.relu_2" + self.model = create_feature_extractor(network, [self.final_layer]) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights, + we make sure that the input and target have 3 channels, and then do Z-Score normalization. + The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). 
+ """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Input normalization + input = torchvision_zscore_norm(input) + target = torchvision_zscore_norm(target) + + # Get model outputs + outs_input = self.model.forward(input)[self.final_layer] + outs_target = self.model.forward(target)[self.final_layer] + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: return x.mean([2, 3], keepdim=keepdim) +def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor: + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0] + x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1] + x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2] + return x + + def subtract_mean(x: torch.Tensor) -> torch.Tensor: mean = [0.406, 0.456, 0.485] x[:, 0, :, :] -= mean[0] diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 08ef9034..c96bec71 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -14,9 +14,8 @@ import unittest import torch -from parameterized import parameterized - from generative.losses import PerceptualLoss +from parameterized import parameterized TEST_CASES = [ [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], @@ -37,6 +36,11 @@ (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], ]