diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 77e271a9eb..8a64ad7cf9 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -13,9 +13,9 @@ import torch from monai.engines import SupervisedEvaluator, SupervisedTrainer -from monai.engines.utils import CommonKeys from monai.engines.workflow import Events from monai.transforms import Compose +from monai.utils.enums import CommonKeys class Interaction: diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 8256680735..d3a14f6104 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,4 +12,4 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer -from .utils import CommonKeys, GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec +from .utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 0afa3747a4..2c237f5245 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -14,13 +14,13 @@ import torch from torch.utils.data import DataLoader -from monai.engines.utils import CommonKeys as Keys from monai.engines.utils import IterationEvents, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode from monai.transforms import Transform from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: from ignite.engine import Engine diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 5b996eafe1..a7b1943211 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -15,12 +15,12 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from monai.engines.utils import CommonKeys as Keys from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform from monai.utils import exact_version, optional_import +from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: from ignite.engine import Engine diff --git a/monai/engines/utils.py b/monai/engines/utils.py index b0b1e44f71..04237d0f4a 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -14,6 +14,7 @@ import torch from monai.utils import exact_version, optional_import +from monai.utils.enums import CommonKeys if TYPE_CHECKING: from ignite.engine import EventEnum @@ -22,7 +23,6 @@ __all__ = [ "IterationEvents", - "CommonKeys", "GanKeys", "get_devices_spec", "default_prepare_batch", @@ -47,23 +47,6 @@ class IterationEvents(EventEnum): OPTIMIZER_COMPLETED = "optimizer_completed" -class CommonKeys: - """ - A set of common keys for dictionary based supervised training process. - `IMAGE` is the input image data. - `LABEL` is the training or evaluation label of segmentation or classification task. - `PRED` is the prediction data of model output. - `LOSS` is the loss value of current iteration. - `INFO` is some useful information during training or evaluation, like loss value, etc. - - """ - - IMAGE = "image" - LABEL = "label" - PRED = "pred" - LOSS = "loss" - - class GanKeys: """ A set of common keys for generative adversarial networks. diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index c749d4bbab..0cfefb715a 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -14,8 +14,8 @@ from threading import RLock from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional -from monai.engines.utils import CommonKeys from monai.utils import exact_version, optional_import +from monai.utils.enums import CommonKeys Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 4d272ac6ff..6a76c96d0c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -17,6 +17,7 @@ Average, BlendMode, ChannelMatching, + CommonKeys, GridSampleMode, GridSamplePadMode, InterpolateMode, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 63d65329af..9920aefe0e 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -29,6 +29,7 @@ "SkipMode", "Method", "InverseKeys", + "CommonKeys", ] @@ -226,3 +227,20 @@ class InverseKeys: EXTRA_INFO = "extra_info" DO_TRANSFORM = "do_transforms" KEY_SUFFIX = "_transforms" + + +class CommonKeys: + """ + A set of common keys for dictionary based supervised training process. + `IMAGE` is the input image data. + `LABEL` is the training or evaluation label of segmentation or classification task. + `PRED` is the prediction data of model output. + `LOSS` is the loss value of current iteration. + `INFO` is some useful information during training or evaluation, like loss value, etc. + + """ + + IMAGE = "image" + LABEL = "label" + PRED = "pred" + LOSS = "loss" diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 92a50a15aa..13608e166c 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -15,11 +15,12 @@ import torch from monai.utils import optional_import +from monai.utils.enums import CommonKeys try: _, has_ignite = optional_import("ignite") - from monai.engines import CommonKeys, SupervisedTrainer + from monai.engines import SupervisedTrainer from monai.utils import ThreadContainer except ImportError: has_ignite = False