Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ class PerceptualLoss(nn.Module):
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al.
"Med3D: Transfer Learning for 3D Medical Image Analysis" .
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .

The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the
three axis.
Expand All @@ -48,11 +49,14 @@ def __init__(
if spatial_dims not in [2, 3]:
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")

if spatial_dims == 2 and "medicalnet_" in network_type:
raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.")

self.spatial_dims = spatial_dims
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False)
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False)
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
else:
self.perceptual_function = LPIPS(
pretrained=True,
Expand Down Expand Up @@ -134,7 +138,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return torch.mean(loss)


class MedicalNetPerceptualComponent(nn.Module):
class MedicalNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
Expand Down Expand Up @@ -200,7 +204,7 @@ def medicalnet_intensity_normalisation(volume):
return (volume - mean) / std


class RadImageNetPerceptualComponent(nn.Module):
class RadImageNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
Expand Down
7 changes: 7 additions & 0 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def test_1d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=1)

def test_medicalnet_on_2d_data(self):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")

with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")


if __name__ == "__main__":
unittest.main()