Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 78 additions & 13 deletions generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
from lpips import LPIPS


# TODO: Define model_path for lpips networks.
# TODO: Add MedicalNet for true 3D computation (https://github.com/Tencent/MedicalNet)
# TODO: Add RadImageNet for 2D computation with networks pretrained using radiological images
# (https://github.com/BMEII-AI/RadImageNet)
class PerceptualLoss(nn.Module):
"""
Perceptual loss using features from pretrained deep neural networks. The function supports networks
Expand All @@ -30,7 +26,8 @@ class PerceptualLoss(nn.Module):

Args:
spatial_dims: number of spatial dimensions.
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``}
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"medicalnet_resnet10_23datasets"``,
``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"alex"``.
is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
Expand All @@ -48,15 +45,15 @@ def __init__(
if spatial_dims not in [2, 3]:
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")

if spatial_dims == 3 and is_fake_3d is False:
raise NotImplementedError("True 3D perceptual loss is not implemented. Try setting is_fake_3d=False")

self.spatial_dims = spatial_dims
self.perceptual_function = LPIPS(
pretrained=True,
net=network_type,
verbose=False,
)
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False)
else:
self.perceptual_function = LPIPS(
pretrained=True,
net=network_type,
verbose=False,
)
self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio

Expand Down Expand Up @@ -127,5 +124,73 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)
loss = loss_sagittal + loss_axial + loss_coronal
if self.spatial_dims == 3 and self.is_fake_3d is False:
loss = self.perceptual_function(input, target)

return torch.mean(loss)


class MedicalNetPerceptualComponent(nn.Module):
    """
    Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
    Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
    "Warvito/MedicalNet-models".

    Args:
        net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
            Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
        verbose: if false, mute messages from torch Hub load function.
    """

    def __init__(
        self,
        net: str = "medicalnet_resnet10_23datasets",
        verbose: bool = False,
    ) -> None:
        super().__init__()
        # NOTE(review): this replaces a *private* torch.hub function with a no-op for the whole process,
        # not just for this instance. Presumably a workaround for the fork validation failing on the
        # hub repository - confirm it is still required with the supported torch versions.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
        # Put the component (and the loaded network) in evaluation mode right after loading.
        self.eval()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the
        pre-trained MedicalNet that is used for feature extraction. Then, these extracted features are normalised across
        the channels. Finally, we compute the difference between the input and target features and calculate the mean
        value from the spatial dimensions to obtain the perceptual loss.

        Args:
            input: 3D input tensor with shape BCDHW.
            target: 3D target tensor with shape BCDHW.

        Returns:
            The squared feature difference summed over channels and averaged over the three spatial dimensions,
            with the reduced dimensions kept (size 1).
        """
        input = medicalnet_intensity_normalisation(input)
        target = medicalnet_intensity_normalisation(target)

        # Get model outputs. Call the module itself rather than ``.forward`` so that any registered
        # forward hooks / pre-hooks are executed.
        outs_input = self.model(input)
        outs_target = self.model(target)

        # Normalise through the channels
        feats_input = normalize_tensor(outs_input)
        feats_target = normalize_tensor(outs_target)

        # Squared difference, summed over the channel dimension, then averaged over D, H and W.
        results = (feats_input - feats_target) ** 2
        results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)

        return results


def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """
    Average a tensor over dimensions 2, 3 and 4.

    Args:
        x: tensor with at least five dimensions; dimensions 2, 3 and 4 are reduced.
        keepdim: if True, the reduced dimensions are retained with size 1.

    Returns:
        The mean of ``x`` over dimensions 2, 3 and 4.
    """
    spatial_axes = (2, 3, 4)
    return torch.mean(x, dim=spatial_axes, keepdim=keepdim)


def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """
    Scale a tensor to (approximately) unit L2 norm along dimension 1.

    Args:
        x: tensor to normalise; the norm is taken over dimension 1 (kept for broadcasting).
        eps: small constant added to the norm to avoid division by zero.

    Returns:
        ``x`` divided by ``(L2 norm over dim 1) + eps``, with the same shape as ``x``.
    """
    # ``torch.sqrt(torch.sum(x**2))`` back-propagates NaN when a vector along dim 1 is exactly zero
    # (0 / 0 in the sqrt backward). ``vector_norm`` yields the same forward value but uses a zero
    # subgradient at the origin, so gradients stay finite.
    norm_factor = torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)
    return x / (norm_factor + eps)


def medicalnet_intensity_normalisation(volume):
    """
    Standardise a volume using its global mean and standard deviation.

    Both statistics are computed over every element of ``volume`` (no dimension is singled out), giving a single
    scalar mean and a single scalar standard deviation.

    Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133

    Args:
        volume: tensor to normalise.

    Returns:
        ``(volume - volume.mean()) / volume.std()``, with the same shape as ``volume``.
    """
    centred = volume - volume.mean()
    return centred / volume.std()
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ matplotlib!=3.5.0
einops
tensorboard>=2.11.0
nibabel>=4.0.2
gdown>=4.4.0
9 changes: 5 additions & 4 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
]


Expand All @@ -52,10 +57,6 @@ def test_different_shape(self):
with self.assertRaises(ValueError):
loss(tensor, target)

def test_true_3d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=3, is_fake_3d=False)

def test_1d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=1)
Expand Down