From 29264e0a3a72eb7de51ba139583a3e01ed390516 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Thu, 23 Feb 2023 11:20:11 +0000 Subject: [PATCH 1/2] Add param.requires_grad = False Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 1d3b7b9d..acdfeaa5 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -137,6 +137,9 @@ def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) self.eval() + for param in self.parameters(): + param.requires_grad = False + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the @@ -198,6 +201,9 @@ def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) self.eval() + for param in self.parameters(): + param.requires_grad = False + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at From bf8835d53a7207b6519756f070f4b3e7a6180e32 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Thu, 23 Feb 2023 11:20:35 +0000 Subject: [PATCH 2/2] Fix utils imports Signed-off-by: Walter Hugo Lopez Pinaya --- tests/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 601bd9e9..1d5b8e9c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,6 @@ # COPIED FROM https://github.com/Project-MONAI/MONAI/blob/fdd07f36ecb91cfcd491533f4792e1a67a9f89fc/tests/utils.py # --------------------------------------------------------------- - -from __future__ import annotations - +# # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.