From 811dc5354cf2f8dc076c2154b11c0917b66eba9a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 7 Apr 2021 14:04:34 +0800 Subject: [PATCH 1/5] [DLMED] add support for addtional events Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 39 +++++++++++++++++++++++++++++--------- monai/engines/trainer.py | 19 +++++++++++++------ monai/engines/workflow.py | 27 +++++++++++++------------- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c1fe79c848..3e67cc23c1 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union, List import torch from torch.utils.data import DataLoader @@ -22,6 +22,8 @@ from monai.utils import ForwardMode, ensure_tuple, exact_version, optional_import from monai.utils.enums import CommonKeys as Keys +EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") + if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric @@ -56,6 +58,10 @@ class Evaluator(Workflow): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + additional_events: addtional custom ignite events that will register to the engine. + new events can be a str or an object derived from `ignite.engine.events.EventEnum`. + additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -73,6 +79,8 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + *additional_events: Union[List[str], List[EventEnum]], + additional_event_to_attr: Optional[dict] = None, ) -> None: super().__init__( device=device, @@ -87,6 +95,8 @@ def __init__( additional_metrics=additional_metrics, handlers=val_handlers, amp=amp, + *additional_events, + additional_event_to_attr=additional_event_to_attr, ) mode = ForwardMode(mode) if mode == ForwardMode.EVAL: @@ -140,6 +150,10 @@ class SupervisedEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + additional_events: addtional custom ignite events that will register to the engine. + new events can be a str or an object derived from `ignite.engine.events.EventEnum`. + additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. 
+ for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -159,7 +173,11 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + *additional_events: Union[List[str], List[EventEnum]], + additional_event_to_attr: Optional[dict] = None, ) -> None: + # add the iteration events + self.register_events(*IterationEvents) super().__init__( device=device, val_data_loader=val_data_loader, @@ -173,15 +191,12 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, + additional_event_to_attr=additional_event_to_attr, ) self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. @@ -251,6 +266,10 @@ class EnsembleEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + additional_events: addtional custom ignite events that will register to the engine. + new events can be a str or an object derived from `ignite.engine.events.EventEnum`. + additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -271,7 +290,11 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + *additional_events: Union[List[str], List[EventEnum]], + additional_event_to_attr: Optional[dict] = None, ) -> None: + # add the iteration events + self.register_events(*IterationEvents) super().__init__( device=device, val_data_loader=val_data_loader, @@ -285,16 +308,14 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, + *additional_events, + additional_event_to_attr=additional_event_to_attr, ) self.networks = ensure_tuple(networks) self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a7b1943211..74bd6c9eb4 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union, List import torch from torch.optim.optimizer import Optimizer @@ -22,6 +22,8 @@ from monai.utils import exact_version, optional_import from monai.utils.enums import CommonKeys as Keys +EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") + if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric @@ -78,6 +80,10 @@ class SupervisedTrainer(Trainer): train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. + additional_events: addtional custom ignite events that will register to the engine. + new events can be a str or an object derived from `ignite.engine.events.EventEnum`. + additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -99,8 +105,11 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, train_handlers: Optional[Sequence] = None, amp: bool = False, + *additional_events: Union[List[str], List[EventEnum]], + additional_event_to_attr: Optional[dict] = None, ) -> None: - # set up Ignite engine and environments + # add the iteration events + self.register_events(*IterationEvents) super().__init__( device=device, max_epochs=max_epochs, @@ -114,6 +123,8 @@ def __init__( additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, + *additional_events, + additional_event_to_attr=additional_event_to_attr, ) self.network = network @@ -121,10 +132,6 @@ def __init__( self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 61b92ac5dd..3cf90b933c 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union, List import torch import torch.distributed as dist @@ -23,6 +23,8 @@ IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") + if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric @@ -60,6 +62,10 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. 
amp: whether to enable auto-mixed-precision training or inference, default is False. + additional_events: addtional custom ignite events that will register to the engine. + new events can be a str or an object derived from `ignite.engine.events.EventEnum`. + additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 Raises: TypeError: When ``device`` is not a ``torch.Device``. @@ -83,6 +89,8 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, handlers: Optional[Sequence] = None, amp: bool = False, + *additional_events: Union[List[str], List[EventEnum]], + additional_event_to_attr: Optional[dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -128,7 +136,7 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - self._register_additional_events() + self.register_events(*additional_events, event_to_attr=additional_event_to_attr) if post_transform is not None: self._register_post_transforms(post_transform) if key_metric is not None: @@ -136,14 +144,7 @@ def set_sampler_epoch(engine: Engine): if handlers is not None: self._register_handlers(handlers) - def _register_additional_events(self): - """ - Register more ignite Events to the engine. - - """ - pass - - def _register_post_transforms(self, posttrans): + def _register_post_transforms(self, posttrans: Callable): """ Register the post transforms to the engine, will execute them as a chain when iteration completed. @@ -151,11 +152,9 @@ def _register_post_transforms(self, posttrans): @self.on(Events.ITERATION_COMPLETED) def run_post_transform(engine: Engine) -> None: - if posttrans is None: - raise AssertionError engine.state.output = apply_transform(posttrans, engine.state.output) - def _register_metrics(self, k_metric, add_metrics): + def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ Register the key metric and additional metrics to the engine, supports ignite Metrics. @@ -180,7 +179,7 @@ def _compare_metrics(engine: Engine) -> None: engine.state.best_metric = current_val_metric engine.state.best_metric_epoch = engine.state.epoch - def _register_handlers(self, handlers): + def _register_handlers(self, handlers: Sequence): """ Register the handlers to the engine, supports ignite Handlers with `attach` API. From 68026e67666f2ba209df874c5d9216ef4d5a6c58 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 7 Apr 2021 19:06:29 +0800 Subject: [PATCH 2/5] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 49 ++++++++++++++++---------------- monai/engines/trainer.py | 19 ++++++------- monai/engines/workflow.py | 24 +++++++++++----- tests/test_ensemble_evaluator.py | 23 ++++++++++++++- 4 files changed, 72 insertions(+), 43 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 3e67cc23c1..f90e972789 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union, List +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader @@ -58,9 +58,9 @@ class Evaluator(Workflow): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - additional_events: addtional custom ignite events that will register to the engine. - new events can be a str or an object derived from `ignite.engine.events.EventEnum`. - additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + event_names: addtional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -79,8 +79,8 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, - *additional_events: Union[List[str], List[EventEnum]], - additional_event_to_attr: Optional[dict] = None, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: super().__init__( device=device, @@ -95,8 +95,8 @@ def __init__( additional_metrics=additional_metrics, handlers=val_handlers, amp=amp, - *additional_events, - additional_event_to_attr=additional_event_to_attr, + event_names=event_names, + event_to_attr=event_to_attr, ) mode = ForwardMode(mode) if mode == ForwardMode.EVAL: @@ -150,9 +150,9 @@ class SupervisedEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - additional_events: addtional custom ignite events that will register to the engine. - new events can be a str or an object derived from `ignite.engine.events.EventEnum`. - additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + event_names: addtional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. 
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -173,11 +173,9 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, - *additional_events: Union[List[str], List[EventEnum]], - additional_event_to_attr: Optional[dict] = None, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: - # add the iteration events - self.register_events(*IterationEvents) super().__init__( device=device, val_data_loader=val_data_loader, @@ -191,7 +189,9 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - additional_event_to_attr=additional_event_to_attr, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.network = network @@ -266,9 +266,9 @@ class EnsembleEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - additional_events: addtional custom ignite events that will register to the engine. - new events can be a str or an object derived from `ignite.engine.events.EventEnum`. - additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + event_names: addtional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -290,11 +290,9 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, - *additional_events: Union[List[str], List[EventEnum]], - additional_event_to_attr: Optional[dict] = None, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: - # add the iteration events - self.register_events(*IterationEvents) super().__init__( device=device, val_data_loader=val_data_loader, @@ -308,8 +306,9 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - *additional_events, - additional_event_to_attr=additional_event_to_attr, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.networks = ensure_tuple(networks) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 74bd6c9eb4..9a9da04884 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union, List +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -80,9 +80,9 @@ class SupervisedTrainer(Trainer): train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. 
- additional_events: addtional custom ignite events that will register to the engine. - new events can be a str or an object derived from `ignite.engine.events.EventEnum`. - additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + event_names: addtional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -105,11 +105,9 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, train_handlers: Optional[Sequence] = None, amp: bool = False, - *additional_events: Union[List[str], List[EventEnum]], - additional_event_to_attr: Optional[dict] = None, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: - # add the iteration events - self.register_events(*IterationEvents) super().__init__( device=device, max_epochs=max_epochs, @@ -123,8 +121,9 @@ def __init__( additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, - *additional_events, - additional_event_to_attr=additional_event_to_attr, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.network = network diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3cf90b933c..8112eafa81 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union, List +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch import torch.distributed as dist @@ -62,9 +62,9 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training or inference, default is False. - additional_events: addtional custom ignite events that will register to the engine. - new events can be a str or an object derived from `ignite.engine.events.EventEnum`. - additional_event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + event_names: addtional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. 
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 Raises: @@ -89,8 +89,8 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, handlers: Optional[Sequence] = None, amp: bool = False, - *additional_events: Union[List[str], List[EventEnum]], - additional_event_to_attr: Optional[dict] = None, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -136,7 +136,17 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - self.register_events(*additional_events, event_to_attr=additional_event_to_attr) + if event_names is not None: + if not isinstance(event_names, list): + raise ValueError("event_names must be a list or string or EventEnum.") + for name in event_names: + if isinstance(name, str): + self.register_events(name, event_to_attr=event_to_attr) + elif issubclass(name, EventEnum): + self.register_events(*name, event_to_attr=event_to_attr) + else: + raise ValueError("event_names must be a list or string or EventEnum.") + if post_transform is not None: self._register_post_transforms(post_transform) if key_metric is not None: diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 9cc977d876..28a2d4f941 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -12,7 +12,7 @@ import unittest import torch -from ignite.engine import Events +from ignite.engine import EventEnum, Events from monai.engines import EnsembleEvaluator @@ -44,11 +44,17 @@ def forward(self, x): net3 = TestNet(lambda x: x + 4) net4 = TestNet(lambda x: x + 5) + class CustomEvents(EventEnum): + FOO_EVENT = "foo_event" + BAR_EVENT = "bar_event" + val_engine = EnsembleEvaluator( device=device, val_data_loader=val_loader, networks=[net0, net1, net2, net3, net4], pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"], + event_names=["bwd_event", "opt_event", CustomEvents], + event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"}, ) @val_engine.on(Events.ITERATION_COMPLETED) @@ -57,6 +63,21 @@ def run_post_transform(engine): expected_value = engine.state.iteration + i torch.testing.assert_allclose(engine.state.output[f"pred{i}"], torch.tensor([[expected_value]])) + @val_engine.on(Events.EPOCH_COMPLETED) + def trigger_custom_event(): + val_engine.fire_event(CustomEvents.FOO_EVENT) + val_engine.fire_event(CustomEvents.BAR_EVENT) + val_engine.fire_event("bwd_event") + val_engine.fire_event("opt_event") + + @val_engine.on(CustomEvents.FOO_EVENT) + def do_foo_op(): + self.assertEqual(val_engine.state.foo, 0) + + @val_engine.on("opt_event") + def do_bar_op(): + self.assertEqual(val_engine.state.opt, 0) + val_engine.run() From 3e9a3911485f3ca54782a9aa055561c190e0e287 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 7 Apr 2021 19:11:06 +0800 Subject: [PATCH 3/5] [DLMED] fix typehints Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 5 ++--- monai/engines/trainer.py | 5 ++--- monai/engines/workflow.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index f90e972789..bc267a1d20 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -22,14 +22,13 @@ from monai.utils import ForwardMode, ensure_tuple, exact_version, optional_import from monai.utils.enums import CommonKeys as Keys -EventEnum, _ = optional_import("ignite.engine", 
"0.4.4", exact_version, "EventEnum") - if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") __all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 9a9da04884..819531f336 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -22,14 +22,13 @@ from monai.utils import exact_version, optional_import from monai.utils.enums import CommonKeys as Keys -EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") - if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") __all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8112eafa81..af2b3e0530 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -23,14 +23,14 @@ IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") -EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import From 094fe33397ba9a75b0d3378a48f3b0fc1a457f89 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 7 Apr 2021 11:16:23 +0000 Subject: [PATCH 4/5] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_handler_garbage_collector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 9f63211a13..c2c5dcbfd6 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -67,7 +67,8 @@ def _train_func(engine, batch): self.assertGreater(gb_count[0], 0) if iter > 1: # Since we are collecting all objects from all generations manually at each call, - # starting from the second call, there shouldn't be any 1st and 2nd generation objects available to collect. + # starting from the second call, there shouldn't be any 1st and 2nd + # generation objects available to collect. 
self.assertEqual(gb_count[1], first_count) self.assertEqual(gb_count[2], first_count) From b18e1fd6011a36c6e69e0b50311dbb98ba3998c3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 7 Apr 2021 13:44:31 +0100 Subject: [PATCH 5/5] fixes typos Signed-off-by: Wenqi Li --- monai/engines/evaluator.py | 6 +++--- monai/engines/trainer.py | 2 +- monai/engines/workflow.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index bc267a1d20..bfa69c0bdd 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -57,7 +57,7 @@ class Evaluator(Workflow): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - event_names: addtional custom ignite events that will register to the engine. + event_names: additional custom ignite events that will register to the engine. new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 @@ -149,7 +149,7 @@ class SupervisedEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - event_names: addtional custom ignite events that will register to the engine. + event_names: additional custom ignite events that will register to the engine. new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 @@ -265,7 +265,7 @@ class EnsembleEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. - event_names: addtional custom ignite events that will register to the engine. + event_names: additional custom ignite events that will register to the engine. new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 819531f336..f14ee7e91f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -79,7 +79,7 @@ class SupervisedTrainer(Trainer): train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. - event_names: addtional custom ignite events that will register to the engine. + event_names: additional custom ignite events that will register to the engine. new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. 
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index af2b3e0530..50a9f41368 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -62,7 +62,7 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training or inference, default is False. - event_names: addtional custom ignite events that will register to the engine. + event_names: additional custom ignite events that will register to the engine. new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160
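
Editor's note: for reviewers who want to exercise the `event_names` / `event_to_attr` parameters introduced by this series end to end, a minimal sketch follows. It mirrors the updated `tests/test_ensemble_evaluator.py` above but swaps in a `SupervisedEvaluator`; the identity network, the toy data loader, and the event and handler names (`custom_str_event`, `CustomEvents.FOO_EVENT`, `on_foo`, ...) are illustrative placeholders, not part of the patch. It assumes torch, ignite 0.4.4, and a MONAI build that includes these commits.

    import torch
    from ignite.engine import EventEnum, Events
    from torch.utils.data import DataLoader

    from monai.engines import SupervisedEvaluator


    class CustomEvents(EventEnum):
        # user-defined events, registered on the engine through `event_names`
        FOO_EVENT = "foo_event"
        BAR_EVENT = "bar_event"


    # one toy image/label pair; the default prepare_batch expects "image"/"label" keys
    data = [{"image": torch.tensor([[1.0]]), "label": torch.tensor([[1.0]])}]
    val_loader = DataLoader(data, batch_size=1)

    evaluator = SupervisedEvaluator(
        device=torch.device("cpu"),
        val_data_loader=val_loader,
        network=torch.nn.Identity(),
        # plain strings and EventEnum subclasses can be mixed, as in the updated test
        event_names=["custom_str_event", CustomEvents],
        # map selected events to attributes that will be created on engine.state
        event_to_attr={CustomEvents.FOO_EVENT: "foo", "custom_str_event": "custom"},
    )


    @evaluator.on(Events.ITERATION_COMPLETED)
    def fire_custom_events(engine):
        # custom events are fired explicitly; any handlers attached to them then run
        engine.fire_event(CustomEvents.FOO_EVENT)
        engine.fire_event("custom_str_event")


    @evaluator.on(CustomEvents.FOO_EVENT)
    def on_foo(engine):
        # the mapped attribute exists on engine.state (initialised to 0 by ignite)
        print("FOO_EVENT fired, engine.state.foo =", engine.state.foo)


    @evaluator.on("custom_str_event")
    def on_custom(engine):
        print("custom_str_event fired, engine.state.custom =", engine.state.custom)


    evaluator.run()

Compared with the positional `*additional_events` form in the first commit, the final keyword-only `event_names` list keeps the engine constructors backwards compatible and lets `SupervisedTrainer`, `SupervisedEvaluator`, and `EnsembleEvaluator` append `IterationEvents` by extending the list, instead of overriding the removed `_register_additional_events` hook.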