From bcee3c521144486e8d0010f665ff629e8b5f3e4b Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 26 Feb 2023 19:35:26 +0000 Subject: [PATCH] Add cache_dir Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 1d3b7b9d..98e7e55a 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -35,10 +35,16 @@ class PerceptualLoss(nn.Module): Specifies the network architecture to use. Defaults to ``"alex"``. is_fake_3d: if True use 2.5D approach for a 3D perceptual loss. fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach. + cache_dir: path to cache directory to save the pretrained network weights. """ def __init__( - self, spatial_dims: int, network_type: str = "alex", is_fake_3d: bool = True, fake_3d_ratio: float = 0.5 + self, + spatial_dims: int, + network_type: str = "alex", + is_fake_3d: bool = True, + fake_3d_ratio: float = 0.5, + cache_dir: str | None = None, ): super().__init__() @@ -51,6 +57,9 @@ def __init__( "Argument is_fake_3d must be set to False." ) + if cache_dir: + torch.hub.set_dir(cache_dir) + self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)