From 4170924ed558aeb42d6966fc22064043ebd429c3 Mon Sep 17 00:00:00 2001 From: binliu Date: Sun, 12 Mar 2023 12:19:10 +0000 Subject: [PATCH 1/4] add TensorRT model check to avoid evaluator error Signed-off-by: binliu --- monai/engines/evaluator.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 43964ee8bc..879e7caa6f 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -295,14 +295,16 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} - - # execute forward computation - with engine.mode(engine.network): - if engine.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + if hasattr(engine.network, "training"): + # execute forward computation + with engine.mode(engine.network): + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) - else: - engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + else: + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) From 6ca348ee82bb6824025294397d821f5419c00a35 Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 14 Mar 2023 07:39:16 +0000 Subject: [PATCH 2/4] fix the network mode issue through another way Signed-off-by: binliu --- monai/engines/evaluator.py | 15 ++++++--------- monai/networks/utils.py | 11 ++++++----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 879e7caa6f..7c6ddd5bdd 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -295,16 +295,13 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} - if hasattr(engine.network, "training"): - # execute forward computation - with engine.mode(engine.network): - if engine.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): - engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) - else: + # execute forward computation + with engine.mode(engine.network): + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) - else: - engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + else: + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d5c0629c05..9b4833fd79 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -376,12 +376,12 @@ def eval_mode(*nets: nn.Module): """ # Get original state of network(s) - training = [n for n in nets if n.training] + training = [n for n in nets if hasattr(n, "training") and n.training] try: # set to eval mode with torch.no_grad(): - yield [n.eval() for n in nets] + yield [n.eval() for n in nets if hasattr(n, "eval")] finally: # Return required networks to training for n in training: @@ -410,16 +410,17 @@ def train_mode(*nets: nn.Module): """ # Get original state of network(s) - eval_list = [n for n in nets if not n.training] + eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)] try: # set to train mode with torch.set_grad_enabled(True): - yield [n.train() for n in nets] + yield [n.train() for n in nets if hasattr(n, "train")] finally: # Return required networks to eval_list for n in eval_list: - n.eval() + if hasattr(n, "eval"): + n.eval() def get_state_dict(obj: torch.nn.Module | Mapping): From 00c1ae3f4d963d66d557f207deb75a9c8327a835 Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 14 Mar 2023 13:07:27 +0000 Subject: [PATCH 3/4] add more attribute check Signed-off-by: binliu --- monai/networks/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9b4833fd79..a83b7d2ebb 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module): print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated """ - # Get original state of network(s) + # Get original state of network(s). + # Check the training attribute in case it's TensorRT based models which don't have this attribute. training = [n for n in nets if hasattr(n, "training") and n.training] try: # set to eval mode with torch.no_grad(): - yield [n.eval() for n in nets if hasattr(n, "eval")] + yield [n.eval() if hasattr(n, "eval") else n for n in nets] finally: # Return required networks to training for n in training: - n.train() + if hasattr(n, "train"): + n.train() @contextmanager @@ -410,12 +412,13 @@ def train_mode(*nets: nn.Module): """ # Get original state of network(s) + # Check the training attribute in case it's TensorRT based models which don't have this attribute. eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)] try: # set to train mode with torch.set_grad_enabled(True): - yield [n.train() for n in nets if hasattr(n, "train")] + yield [n.train() if hasattr(n, "train") else n for n in nets if hasattr(n, "train")] finally: # Return required networks to eval_list for n in eval_list: From 2a244fe9fd661e4ad73b26f41ecbd06f4327a35c Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 14 Mar 2023 13:58:38 +0000 Subject: [PATCH 4/4] remove the legacy condition code Signed-off-by: binliu --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a83b7d2ebb..b79ae8e9bd 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -418,7 +418,7 @@ def train_mode(*nets: nn.Module): try: # set to train mode with torch.set_grad_enabled(True): - yield [n.train() if hasattr(n, "train") else n for n in nets if hasattr(n, "train")] + yield [n.train() if hasattr(n, "train") else n for n in nets] finally: # Return required networks to eval_list for n in eval_list: