diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 30314467..932ae613 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -21,8 +21,9 @@ class PerceptualLoss(nn.Module): Perceptual loss using features from pretrained deep neural networks trained. The function supports networks pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An - Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al. - "Med3D: Transfer Learning for 3D Medical Image Analysis" . + Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning" + https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for + 3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 . The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the three axis. @@ -48,11 +49,14 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") + if spatial_dims == 2 and "medicalnet_" in network_type: + raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.") + self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: - self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False) + self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) elif "radimagenet_" in network_type: - self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False) + self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) else: self.perceptual_function = LPIPS( pretrained=True, @@ -134,7 +138,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return torch.mean(loss) -class MedicalNetPerceptualComponent(nn.Module): +class MedicalNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from @@ -200,7 +204,7 @@ def medicalnet_intensity_normalisation(volume): return (volume - mean) / std -class RadImageNetPerceptualComponent(nn.Module): +class RadImageNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index bd1521ea..ff20758e 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -76,6 +76,13 @@ def test_1d(self): with self.assertRaises(NotImplementedError): PerceptualLoss(spatial_dims=1) + def test_medicalnet_on_2d_data(self): + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets") + + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets") + if __name__ == "__main__": unittest.main()