Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,16 @@ class PerceptualLoss(nn.Module):
Specifies the network architecture to use. Defaults to ``"alex"``.
is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
cache_dir: path to cache directory to save the pretrained network weights.
"""

def __init__(
self, spatial_dims: int, network_type: str = "alex", is_fake_3d: bool = True, fake_3d_ratio: float = 0.5
self,
spatial_dims: int,
network_type: str = "alex",
is_fake_3d: bool = True,
fake_3d_ratio: float = 0.5,
cache_dir: str | None = None,
):
super().__init__()

Expand All @@ -51,6 +57,9 @@ def __init__(
"Argument is_fake_3d must be set to False."
)

if cache_dir:
torch.hub.set_dir(cache_dir)

self.spatial_dims = spatial_dims
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
Expand Down