diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 68b94d75..5a3640b8 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -45,8 +45,11 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if spatial_dims == 2 and "medicalnet_" in network_type: - raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.") + if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``." + "Argument is_fake_3d must be set to False." + ) self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: