From 364eb99a1f1d0b6b2b6c80fdab61f747d049dca9 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:32:37 +0000 Subject: [PATCH 01/17] add compile arg Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 21 +++++++++++++++++---- monai/engines/trainer.py | 20 +++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 119853d5c5..0ba2f9edc1 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -17,6 +17,7 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo, KeysCollection +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,7 +26,7 @@ from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -213,6 +214,10 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + compile: whether to use compile, default is False. + If set to True, the inputs will be converted to `torch.Tensor` internally. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ @@ -238,6 +243,8 @@ def __init__( decollate: bool = True, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict = {}, ) -> None: super().__init__( device=device, @@ -259,8 +266,12 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - - self.network = network + if compile: + assert pytorch_after(2, 1) + self.network = torch.compile(network, **compile_kwargs) + else: + self.network = network + self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: @@ -288,7 +299,9 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch - + if self.compile: + inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs + targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 61b7028e11..91e5aa1fab 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -18,11 +18,13 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform from monai.utils import GanKeys, min_version, optional_import +from monai.utils.module import pytorch_after from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys @@ -125,7 +127,10 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - + compile: whether to use compile, default is False. + If set to True, the inputs will be converted to `torch.Tensor` internally. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ def __init__( @@ -153,6 +158,8 @@ def __init__( optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict = {}, ) -> None: super().__init__( device=device, @@ -174,8 +181,12 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - - self.network = network + if compile: + assert pytorch_after(2, 1) + self.network = torch.compile(network, **compile_kwargs) + else: + self.network = network + self.compile = compile self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer @@ -207,6 +218,9 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + if self.compile: + inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs + targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} From ec2090d168460e3c4cf6c03c0cdb112db7d1c90b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:47:49 +0000 Subject: [PATCH 02/17] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 91e5aa1fab..3b0d0f9cb2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -24,9 +24,9 @@ from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform from monai.utils import GanKeys, min_version, optional_import -from monai.utils.module import pytorch_after from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys +from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum From a57b1886a9a0a5c1535ecb6386a8466315566cef Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:22:09 +0800 Subject: [PATCH 03/17] Update monai/engines/evaluator.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 0ba2f9edc1..d0683e94c0 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -267,10 +267,10 @@ def __init__( amp_kwargs=amp_kwargs, ) if compile: - assert pytorch_after(2, 1) - self.network = torch.compile(network, **compile_kwargs) - else: - self.network = network + if pytorch_after(2, 1): + self.network = torch.compile(network, **compile_kwargs) + else: + warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.2, no compilation done") self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer From 3f9dfb25032508206ed537632dda22b30df7d829 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:25:01 +0800 Subject: [PATCH 04/17] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 3 ++- monai/engines/trainer.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index d0683e94c0..e159742a15 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -270,7 +271,7 @@ def __init__( if pytorch_after(2, 1): self.network = torch.compile(network, **compile_kwargs) else: - warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.2, no compilation done") + warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done") self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 3b0d0f9cb2..5da474f420 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -182,8 +183,10 @@ def __init__( amp_kwargs=amp_kwargs, ) if compile: - assert pytorch_after(2, 1) - self.network = torch.compile(network, **compile_kwargs) + if pytorch_after(2, 1): + self.network = torch.compile(network, **compile_kwargs) + else: + warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done") else: self.network = network self.compile = compile From 707259d2050a75021be470f857ca0120aa1c17ca Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:28:31 +0800 Subject: [PATCH 05/17] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 4 +++- monai/engines/trainer.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e159742a15..c5b7ed5d29 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -271,7 +271,9 @@ def __init__( if pytorch_after(2, 1): self.network = torch.compile(network, **compile_kwargs) else: - warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done") + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 5da474f420..1c15f922af 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -186,7 +186,9 @@ def __init__( if pytorch_after(2, 1): self.network = torch.compile(network, **compile_kwargs) else: - warnings.warn("Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done") + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) else: self.network = network self.compile = compile From 1c5a1fd97a00c8117e6d152cde8ba0f7375984ef Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:44:43 +0800 Subject: [PATCH 06/17] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 3 ++- monai/engines/trainer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c5b7ed5d29..5c6846fdd5 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -245,7 +245,7 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, compile: bool = False, - compile_kwargs: dict = {}, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -269,6 +269,7 @@ def __init__( ) if compile: if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs self.network = torch.compile(network, **compile_kwargs) else: warnings.warn( diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 1c15f922af..c13f34bd1a 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -160,7 +160,7 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, compile: bool = False, - compile_kwargs: dict = {}, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -184,6 +184,7 @@ def __init__( ) if compile: if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs self.network = torch.compile(network, **compile_kwargs) else: warnings.warn( From b65cbce7f7186584fa517df3d4d06991f7f19511 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 10 Jan 2024 10:55:05 +0800 Subject: [PATCH 07/17] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 3 ++- monai/engines/trainer.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 5c6846fdd5..8b0c140bfa 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -270,11 +270,12 @@ def __init__( if compile: if pytorch_after(2, 1): compile_kwargs = {} if compile_kwargs is None else compile_kwargs - self.network = torch.compile(network, **compile_kwargs) + network = torch.compile(network, **compile_kwargs) else: warnings.warn( "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" ) + self.network = network self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c13f34bd1a..96e8f02eab 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -185,13 +185,12 @@ def __init__( if compile: if pytorch_after(2, 1): compile_kwargs = {} if compile_kwargs is None else compile_kwargs - self.network = torch.compile(network, **compile_kwargs) + network = torch.compile(network, **compile_kwargs) else: warnings.warn( "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" ) - else: - self.network = network + self.network = network self.compile = compile self.optimizer = optimizer self.loss_function = loss_function From c3b9b96ac41c35f81d021b203c73a79279bddb0f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:58:31 +0800 Subject: [PATCH 08/17] add fixme Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 1 + monai/engines/trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 8b0c140bfa..ccae3de883 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -304,6 +304,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 96e8f02eab..e0c4e856b7 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -223,6 +223,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets From c6fbe8ccf42fc5bbe17f01e5c0b0839d89fc6c81 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:02:14 +0800 Subject: [PATCH 09/17] add meta back Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index e0c4e856b7..e773327bc2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -225,8 +225,12 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso inputs, targets, args, kwargs = batch # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: - inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs - targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets + inputs_meta, targets_meta = None, None + if isinstance(inputs, MetaTensor): + inputs, inputs_meta = inputs.as_tensor(), inputs.meta + if isinstance(targets, MetaTensor): + targets, targets_meta = targets.as_tensor(), targets.meta + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -252,7 +256,11 @@ def _compute_pred_loss(): engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() engine.fire_event(IterationEvents.MODEL_COMPLETED) - + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta) return engine.state.output From 631b24fdce821c302a344310d430d37be0bc5729 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 07:02:54 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index e773327bc2..d0739d5260 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -230,7 +230,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso inputs, inputs_meta = inputs.as_tensor(), inputs.meta if isinstance(targets, MetaTensor): targets, targets_meta = targets.as_tensor(), targets.meta - + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} From aeb4be97660201ee120a8183b48e59c3aafc3a14 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jan 2024 07:46:59 +0000 Subject: [PATCH 11/17] add meta back Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 18 ++++++++++++++++-- monai/engines/trainer.py | 18 ++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index ccae3de883..7e27f5033c 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -306,8 +306,13 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten inputs, targets, args, kwargs = batch # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: - inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs - targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn("Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.") + inputs, inputs_meta, inputs_applied_operations = inputs.as_tensor(), inputs.meta, inputs.applied_operations + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = targets.as_tensor(), targets.meta, targets.applied_operations + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation @@ -317,6 +322,15 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta, applied_operations=inputs_applied_operations) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta, applied_operations=targets_applied_operations) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index d0739d5260..b7323ede58 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -225,11 +225,12 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso inputs, targets, args, kwargs = batch # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: - inputs_meta, targets_meta = None, None + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None if isinstance(inputs, MetaTensor): - inputs, inputs_meta = inputs.as_tensor(), inputs.meta + warnings.warn("Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.") + inputs, inputs_meta, inputs_applied_operations = inputs.as_tensor(), inputs.meta, inputs.applied_operations if isinstance(targets, MetaTensor): - targets, targets_meta = targets.as_tensor(), targets.meta + targets, targets_meta, targets_applied_operations = targets.as_tensor(), targets.meta, targets.applied_operations # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -255,12 +256,17 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() - engine.fire_event(IterationEvents.MODEL_COMPLETED) + # copy back meta info if self.compile: if inputs_meta is not None: - engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta) + engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta, applied_operations=inputs_applied_operations) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) if targets_meta is not None: - engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta) + engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta, applied_operations=targets_applied_operations) + engine.fire_event(IterationEvents.MODEL_COMPLETED) + return engine.state.output From 2ba99df2153fa183ce3a1500248927c0eb438350 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:48:19 +0800 Subject: [PATCH 12/17] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 24 +++++++++++++++++++----- monai/engines/trainer.py | 24 +++++++++++++++++++----- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 7e27f5033c..40c873ad26 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -308,10 +308,20 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten if self.compile: inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None if isinstance(inputs, MetaTensor): - warnings.warn("Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.") - inputs, inputs_meta, inputs_applied_operations = inputs.as_tensor(), inputs.meta, inputs.applied_operations + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) if isinstance(targets, MetaTensor): - targets, targets_meta, targets_applied_operations = targets.as_tensor(), targets.meta, targets.applied_operations + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -325,12 +335,16 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # copy back meta info if self.compile: if inputs_meta is not None: - engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta, applied_operations=inputs_applied_operations) + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) engine.state.output[Keys.PRED] = MetaTensor( engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations ) if targets_meta is not None: - engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta, applied_operations=targets_applied_operations) + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index b7323ede58..8b06dec6a5 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -227,10 +227,20 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso if self.compile: inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None if isinstance(inputs, MetaTensor): - warnings.warn("Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.") - inputs, inputs_meta, inputs_applied_operations = inputs.as_tensor(), inputs.meta, inputs.applied_operations + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) if isinstance(targets, MetaTensor): - targets, targets_meta, targets_applied_operations = targets.as_tensor(), targets.meta, targets.applied_operations + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -259,12 +269,16 @@ def _compute_pred_loss(): # copy back meta info if self.compile: if inputs_meta is not None: - engine.state.output[Keys.IMAGE] = MetaTensor(inputs, meta=inputs_meta, applied_operations=inputs_applied_operations) + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) engine.state.output[Keys.PRED] = MetaTensor( engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations ) if targets_meta is not None: - engine.state.output[Keys.LABEL] = MetaTensor(targets, meta=targets_meta, applied_operations=targets_applied_operations) + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output From c9b811568f89647c754f146635cdc206ee4c7809 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:53:02 +0800 Subject: [PATCH 13/17] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 4 ++-- monai/engines/trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 40c873ad26..8d3724508f 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -215,8 +215,8 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - compile: whether to use compile, default is False. - If set to True, the inputs will be converted to `torch.Tensor` internally. + compile: whether to use compile, default is False. If set to True, the inputs will be converted to `torch.Tensor` + and copy back the meta information after forward pass internally. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 8b06dec6a5..8913a29f3b 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -128,8 +128,8 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - compile: whether to use compile, default is False. - If set to True, the inputs will be converted to `torch.Tensor` internally. + compile: whether to use compile, default is False. If set to True, the inputs will be converted to `torch.Tensor` + and copy back the meta information after forward pass internally. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ From 02e088b205c24847478577daf700c0ef3d782c64 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:48:57 +0800 Subject: [PATCH 14/17] Update monai/engines/evaluator.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 8d3724508f..08ed4a25ab 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -215,8 +215,8 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - compile: whether to use compile, default is False. If set to True, the inputs will be converted to `torch.Tensor` - and copy back the meta information after forward pass internally. + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. From f3ae748c027a30fa61e938f07d07abafe360f551 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 04:49:19 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 08ed4a25ab..1614b67cbb 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -215,7 +215,7 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to `torch.Tensor` before forward pass, then converted back afterward with copied meta information. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. From 7e212111cad6b528441a3c49618fff306b9eae69 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:50:19 +0800 Subject: [PATCH 16/17] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 8913a29f3b..5022dd4530 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -128,8 +128,8 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - compile: whether to use compile, default is False. If set to True, the inputs will be converted to `torch.Tensor` - and copy back the meta information after forward pass internally. + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ From 10051df7a46b4678a6bc7ea751f9986f8a212854 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 19 Jan 2024 10:28:01 +0800 Subject: [PATCH 17/17] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/engines/evaluator.py | 2 +- monai/engines/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 1614b67cbb..2c8dfe6b85 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -270,7 +270,7 @@ def __init__( if compile: if pytorch_after(2, 1): compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] else: warnings.warn( "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 5022dd4530..f1513ea73b 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -185,7 +185,7 @@ def __init__( if compile: if pytorch_after(2, 1): compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] else: warnings.warn( "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"