From b3e00c55aa5684714d32fb9bdcd6bb09da299090 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 1 Feb 2023 11:58:08 +0000 Subject: [PATCH 1/2] Flag is_fake_3d has to be set to False if you want to use PerceptualLoss with 3D networks; otherwise, error happens. Modified the error in the __init__ to account for this flag setting. --- generative/losses/perceptual.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 68b94d75..70f691f1 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -45,8 +45,9 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if spatial_dims == 2 and "medicalnet_" in network_type: - raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.") + if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: + raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``." + "Argument is_fake_3d must be set to False.") self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: From e283348dee5ab49a7a98de486ba4d8fc84e3bc12 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 1 Feb 2023 13:23:24 +0000 Subject: [PATCH 2/2] Flag is_fake_3d has to be set to False if you want to use PerceptualLoss with 3D networks; otherwise, error happens. Modified the error in the __init__ to account for this flag setting. --- generative/losses/perceptual.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 70f691f1..5a3640b8 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -46,8 +46,10 @@ def __init__( raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: - raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``." - "Argument is_fake_3d must be set to False.") + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``." + "Argument is_fake_3d must be set to False." + ) self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: