diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index 43964ee8bc..7c6ddd5bdd 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -295,7 +295,6 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten

         # put iteration outputs into engine.state
         engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
-
         # execute forward computation
         with engine.mode(engine.network):
             if engine.amp:
diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index d5c0629c05..b79ae8e9bd 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module):
         print(p(t).sum().backward())  # will correctly raise an exception as gradients are calculated
     """

-    # Get original state of network(s)
-    training = [n for n in nets if n.training]
+    # Get original state of network(s).
+    # Check the training attribute in case it's TensorRT based models which don't have this attribute.
+    training = [n for n in nets if hasattr(n, "training") and n.training]

     try:
         # set to eval mode
         with torch.no_grad():
-            yield [n.eval() for n in nets]
+            yield [n.eval() if hasattr(n, "eval") else n for n in nets]
     finally:
         # Return required networks to training
         for n in training:
-            n.train()
+            if hasattr(n, "train"):
+                n.train()


 @contextmanager
@@ -410,16 +412,18 @@ def train_mode(*nets: nn.Module):
     """

     # Get original state of network(s)
-    eval_list = [n for n in nets if not n.training]
+    # Check the training attribute in case it's TensorRT based models which don't have this attribute.
+    eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)]

     try:
         # set to train mode
         with torch.set_grad_enabled(True):
-            yield [n.train() for n in nets]
+            yield [n.train() if hasattr(n, "train") else n for n in nets]
     finally:
         # Return required networks to eval_list
         for n in eval_list:
-            n.eval()
+            if hasattr(n, "eval"):
+                n.eval()


 def get_state_dict(obj: torch.nn.Module | Mapping):
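
The usage sketch below is not part of the patch; it only illustrates the behaviour the hasattr guards are meant to enable: passing an object that lacks the nn.Module train/eval API, such as a TensorRT engine wrapper, through eval_mode and train_mode. FakeTRTModule is a hypothetical stand-in, and the assertions assume nothing beyond the patched monai.networks.utils shown above.

# Sketch only: FakeTRTModule is a hypothetical stand-in for a TensorRT-based model
# that has no `training` attribute and no `eval()`/`train()` methods.
import torch
import torch.nn as nn

from monai.networks.utils import eval_mode, train_mode


class FakeTRTModule:
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


torch_net = nn.Linear(4, 2)  # regular module, starts in training mode
trt_like = FakeTRTModule()   # no training/eval/train members

# With the hasattr checks, objects lacking the attributes are yielded unchanged
# instead of raising AttributeError.
with eval_mode(torch_net, trt_like) as (net_eval, trt_eval):
    assert not net_eval.training    # real module switched to eval
    assert trt_eval is trt_like     # TensorRT-like object passed through untouched
    _ = trt_eval(torch.ones(1, 4))  # still callable inside the no_grad block

with train_mode(torch_net, trt_like) as (net_train, _):
    assert net_train.training       # real module switched (back) to train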