diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 306be5f2db..0c1c75a1b8 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -18,6 +18,7 @@ from monai.engines.utils import default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer +from monai.networks.utils import eval_mode from monai.transforms import Transform from monai.utils import ensure_tuple, exact_version, optional_import @@ -190,8 +191,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # execute forward computation - self.network.eval() - with torch.no_grad(): + with eval_mode(self.network): if self.amp: with torch.cuda.amp.autocast(): predictions = self.inferer(inputs, self.network, *args, **kwargs) @@ -298,8 +298,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict # execute forward computation predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets} for idx, network in enumerate(self.networks): - network.eval() - with torch.no_grad(): + with eval_mode(network): if self.amp: with torch.cuda.amp.autocast(): predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 1bcccd084c..3af177f1f8 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -13,6 +13,7 @@ """ import warnings +from contextlib import contextmanager from typing import Any, Callable, Optional, Sequence, cast import torch @@ -29,6 +30,7 @@ "normal_init", "icnr_init", "pixelshuffle", + "eval_mode", ] @@ -241,3 +243,36 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T x = x.reshape(batch_size, org_channels, *([factor] * dim + input_size[2:])) x = x.permute(permute_indices).reshape(output_size) return x + + +@contextmanager +def eval_mode(*nets: nn.Module): + """ + Set network(s) to eval mode and then return to original state at the end. + + Args: + nets: Input network(s) + + Examples + + .. code-block:: python + + t=torch.rand(1,1,16,16) + p=torch.nn.Conv2d(1,1,3) + print(p.training) # True + with eval_mode(p): + print(p.training) # False + print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated + """ + + # Get original state of network(s) + training = [n for n in nets if n.training] + + try: + # set to eval mode + with torch.no_grad(): + yield [n.eval() for n in nets] + finally: + # Return required networks to training + for n in training: + n.train() diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 78d2cebac3..805711aba6 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -127,8 +128,7 @@ class TestFCN(unittest.TestCase): @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): net = FCN(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -138,8 +138,7 @@ class TestFCNWithPretrain(unittest.TestCase): @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): net = test_pretrained_networks(FCN, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -148,8 +147,7 @@ class TestMCFCN(unittest.TestCase): @parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = MCFCN(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -158,8 +156,7 @@ class TestMCFCNWithPretrain(unittest.TestCase): @parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = test_pretrained_networks(MCFCN, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -174,8 +171,7 @@ class TestAHNET(unittest.TestCase): ) def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -189,8 +185,7 @@ def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): @skip_if_quick def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -213,8 +208,7 @@ def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_p net = AHNet(**input_param).to(device) net2d = FCN(**fcn_input_param).to(device) net.copy_from(net2d) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -230,7 +224,7 @@ def test_initialize_pretrained(self): progress=True, ).to(device) input_data = torch.randn(2, 2, 32, 32, 64).to(device) - with torch.no_grad(): + with eval_mode(net): result = net.forward(input_data) self.assertEqual(result.shape, (2, 3, 32, 32, 64)) diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 86b31e0361..3ee0b640c7 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -3,6 +3,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act from monai.networks.nets import AutoEncoder from tests.utils import test_script_save @@ -75,8 +76,7 @@ class TestAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = AutoEncoder(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index c2494dc2d3..e55a9ab516 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import BasicUNet from tests.utils import test_script_save @@ -95,8 +96,7 @@ def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" print(input_param) net = BasicUNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index 00d0eab65a..68b109ff1a 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import ChannelPad TEST_CASES_3D = [] @@ -34,8 +35,7 @@ class TestChannelPad(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelPad(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(list(result.shape), list(expected_shape)) diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index 436cb5430b..10a3380b76 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.transforms import CopyItemsd from monai.utils import ensure_tuple @@ -61,8 +62,7 @@ def test_array_values(self): def test_graph_tensor_values(self): device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") net = torch.nn.PReLU().to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): pred = net(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)) input_data = {"pred": pred, "seg": torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)} result = CopyItemsd(keys="pred", times=1, names="pred_1")(input_data) diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 183c5443b2..517ff43654 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -66,8 +67,7 @@ class TestPretrainedDENSENET(unittest.TestCase): @skip_if_quick def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): net = test_pretrained_networks(model, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -76,8 +76,7 @@ class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_densenet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 2123737d05..956f0f9b5b 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import Discriminator from tests.utils import test_script_save @@ -42,8 +43,7 @@ class TestDiscriminator(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = Discriminator(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index c2da0f9a43..ac3352acb5 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import MaxAvgPool TEST_CASES = [ @@ -41,8 +42,7 @@ class TestMaxAvgPool(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = MaxAvgPool(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 6b89c8c4fd..565188df80 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import DynUNet from tests.utils import test_script_save @@ -107,8 +108,7 @@ class TestDynUNet(unittest.TestCase): @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result[0].shape, expected_shape) diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 3a0bfa5e7e..776bf0db90 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding from tests.utils import test_script_save @@ -70,8 +71,7 @@ class TestResBasicBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_RES_BASIC_BLOCK) def test_shape(self, input_param, input_shape, expected_shape): for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]: - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) @@ -94,8 +94,7 @@ class TestUpBlock(unittest.TestCase): @parameterized.expand(TEST_UP_BLOCK) def test_shape(self, input_param, input_shape, expected_shape, skip_shape): net = UnetUpBlock(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape), torch.randn(skip_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py new file mode 100644 index 0000000000..b8d9df5880 --- /dev/null +++ b/tests/test_eval_mode.py @@ -0,0 +1,31 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.networks.utils import eval_mode + + +class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): + t = torch.rand(1, 1, 4, 4) + p = torch.nn.Conv2d(1, 1, 3) + self.assertTrue(p.training) # True + with eval_mode(p): + self.assertFalse(p.training) # False + with self.assertRaises(RuntimeError): + p(t).sum().backward() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index 1819c4cdb9..4c7276d47a 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -3,6 +3,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import FullyConnectedNet, VarFullyConnectedNet device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -45,8 +46,7 @@ def test_fc_shape(self, dropout): @parameterized.expand(VFC_CASES) def test_vfc_shape(self, input_param, input_shape, expected_shape): net = VarFullyConnectedNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device))[0] self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_generator.py b/tests/test_generator.py index 46b469b111..f6206bd04a 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import Generator from tests.utils import test_script_save @@ -42,8 +43,7 @@ class TestGenerator(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = Generator(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 10f4f41fea..a1b58e1524 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import HighResNet from tests.utils import DistTestCase, TimedCall, test_script_save @@ -48,8 +49,7 @@ class TestHighResNet(DistTestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): net = HighResNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index aa0fd57f76..d6ac39eb38 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -21,6 +21,7 @@ import monai from monai.apps import download_and_extract from monai.metrics import compute_roc_auc +from monai.networks import eval_mode from monai.networks.nets import densenet121 from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor from monai.utils import set_determinism @@ -102,8 +103,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}") if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): + with eval_mode(model): y_pred = torch.tensor([], dtype=torch.float32, device=device) y = torch.tensor([], dtype=torch.long, device=device) for val_data in val_loader: @@ -137,10 +137,9 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10 model_filename = os.path.join(root_dir, "best_metric_model.pth") model.load_state_dict(torch.load(model_filename)) - model.eval() y_true = list() y_pred = list() - with torch.no_grad(): + with eval_mode(model): for test_data in val_loader: test_images, test_labels = test_data[0].to(device), test_data[1].to(device) pred = model(test_images).argmax(dim=1) diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 9de7dcf362..0e1bdbe453 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -24,6 +24,7 @@ from monai.data import NiftiSaver, create_test_image_3d from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric +from monai.networks import eval_mode from monai.networks.nets import UNet from monai.transforms import ( Activations, @@ -138,8 +139,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}") if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): + with eval_mode(model): metric_sum = 0.0 metric_count = 0 val_images = None @@ -207,8 +207,7 @@ def run_inference_test(root_dir, device="cuda:0"): model_filename = os.path.join(root_dir, "best_metric_model.pth") model.load_state_dict(torch.load(model_filename)) - model.eval() - with torch.no_grad(): + with eval_mode(model): metric_sum = 0.0 metric_count = 0 # resampling with align_corners=True or dtype=float64 will generate diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index 74c6a82350..4ed4f39951 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -22,7 +22,7 @@ from monai.data import NiftiDataset, create_test_image_3d from monai.handlers import SegmentationSaver from monai.inferers import sliding_window_inference -from monai.networks import predict_segmentation +from monai.networks import eval_mode, predict_segmentation from monai.networks.nets import UNet from monai.transforms import AddChannel from monai.utils import set_determinism @@ -40,9 +40,8 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): sw_batch_size = batch_size def _sliding_window_processor(_engine, batch): - net.eval() img, seg, meta_data = batch - with torch.no_grad(): + with eval_mode(net): seg_probs = sliding_window_inference(img.to(device), roi_size, sw_batch_size, net, device=device) return predict_segmentation(seg_probs) diff --git a/tests/test_se_block.py b/tests/test_se_block.py index ce3ffc89c9..baed373378 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SEBlock from monai.networks.layers.factories import Act, Norm from tests.utils import test_script_save @@ -63,8 +64,7 @@ class TestSEBlockLayer(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SEBlock(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index 654cd1f1bf..48cc2b549f 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import ChannelSELayer, ResidualSELayer from tests.utils import test_script_save @@ -41,8 +42,7 @@ class TestChannelSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelSELayer(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) @@ -61,8 +61,7 @@ class TestResidualSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES[:1]) def test_shape(self, input_param, input_shape, expected_shape): net = ResidualSELayer(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index fc1325d94e..e39b236391 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import SegResNet, SegResNetVAE from monai.utils import UpsampleMode from tests.utils import test_script_save @@ -82,8 +83,7 @@ class TestResNet(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2) def test_shape(self, input_param, input_shape, expected_shape): net = SegResNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -102,7 +102,7 @@ class TestResNetVAE(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_VAE) def test_vae_shape(self, input_param, input_shape, expected_shape): net = SegResNetVAE(**input_param).to(device) - with torch.no_grad(): + with eval_mode(net): result, _ = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index 0598362619..1622e73fe0 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.segresnet_block import ResBlock TEST_CASE_RESBLOCK = [] @@ -39,8 +40,7 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_RESBLOCK) def test_shape(self, input_param, input_shape, expected_shape): net = ResBlock(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_senet.py b/tests/test_senet.py index d4fdbe28a7..2e381f270e 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import ( se_resnet50, se_resnet101, @@ -42,9 +43,8 @@ class TestSENET(unittest.TestCase): def test_senet_shape(self, net, net_args): input_data = torch.randn(2, 2, 64, 64, 64).to(device) expected_shape = (2, 2) - net = net(**net_args) - net = net.to(device).eval() - with torch.no_grad(): + net = net(**net_args).to(device) + with eval_mode(net): result = net(input_data) self.assertEqual(result.shape, expected_shape) @@ -65,8 +65,8 @@ def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) expected_shape = (3, 2) - net = net.to(device).eval() - with torch.no_grad(): + net = net.to(device) + with eval_mode(net): result = net(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index da3ed3ecb2..fd1a38de44 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SimpleASPP TEST_CASES = [ @@ -69,8 +70,7 @@ class TestChannelSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = SimpleASPP(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index aa6b4a35c3..23a22fc3b0 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import SkipConnection TEST_CASES_3D = [] @@ -35,8 +36,7 @@ class TestSkipConnection(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index 92e12ecf6c..c83f0d67ec 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -15,6 +15,7 @@ import torch.nn as nn from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SubpixelUpsample from monai.networks.layers.factories import Conv @@ -75,8 +76,7 @@ class TestSUBPIXEL(unittest.TestCase): @parameterized.expand(TEST_CASE_SUBPIXEL) def test_subpixel_shape(self, input_param, input_shape, expected_shape): net = SubpixelUpsample(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_unet.py b/tests/test_unet.py index ed05fce552..33fd5ce044 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act, Norm from monai.networks.nets import UNet from tests.utils import test_script_save @@ -121,8 +122,7 @@ class TestUNET(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = UNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index aa3a1fb90a..6f7eee7e1d 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import UpSample from monai.utils import UpsampleMode @@ -85,8 +86,7 @@ class TestUpsample(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_EQ) def test_shape(self, input_param, input_shape, expected_shape): net = UpSample(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index b2bb1c22e9..bdfd4f8e5b 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -3,6 +3,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act from monai.networks.nets import VarAutoEncoder from tests.utils import test_script_save @@ -70,8 +71,7 @@ class TestVarAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = VarAutoEncoder(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device))[0] self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 062e171655..c7e3a3bffb 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import VNet from tests.utils import test_script_save @@ -64,8 +65,7 @@ class TestVNet(unittest.TestCase): ) def test_vnet_shape(self, input_param, input_shape, expected_shape): net = VNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape)