diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c1fe79c848..bfa69c0bdd 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader @@ -23,11 +23,12 @@ from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") __all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] @@ -56,6 +57,10 @@ class Evaluator(Workflow): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -73,6 +78,8 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: super().__init__( device=device, @@ -87,6 +94,8 @@ def __init__( additional_metrics=additional_metrics, handlers=val_handlers, amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, ) mode = ForwardMode(mode) if mode == ForwardMode.EVAL: @@ -140,6 +149,10 @@ class SupervisedEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -159,6 +172,8 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: super().__init__( device=device, @@ -173,15 +188,14 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. @@ -251,6 +265,10 @@ class EnsembleEvaluator(Evaluator): amp: whether to enable auto-mixed-precision evaluation, default is False. mode: model forward mode during evaluation, should be 'eval' or 'train', which maps to `model.eval()` or `model.train()`, default to 'eval'. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -271,6 +289,8 @@ def __init__( val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: super().__init__( device=device, @@ -285,16 +305,15 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.networks = ensure_tuple(networks) self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a7b1943211..f14ee7e91f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -23,11 +23,12 @@ from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") __all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] @@ -78,6 +79,10 @@ class SupervisedTrainer(Trainer): train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 """ @@ -99,8 +104,9 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, train_handlers: Optional[Sequence] = None, amp: bool = False, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: - # set up Ignite engine and environments super().__init__( device=device, max_epochs=max_epochs, @@ -114,6 +120,9 @@ def __init__( additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, + # add the iteration events + event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_to_attr=event_to_attr, ) self.network = network @@ -121,10 +130,6 @@ def __init__( self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer - def _register_additional_events(self): - super()._register_additional_events() - self.register_events(*IterationEvents) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 61b92ac5dd..50a9f41368 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch import torch.distributed as dist @@ -23,12 +23,14 @@ IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") + if TYPE_CHECKING: - from ignite.engine import Engine + from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import @@ -60,6 +62,10 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training or inference, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 Raises: TypeError: When ``device`` is not a ``torch.Device``. @@ -83,6 +89,8 @@ def __init__( additional_metrics: Optional[Dict[str, Metric]] = None, handlers: Optional[Sequence] = None, amp: bool = False, + event_names: Optional[List[Union[str, EventEnum]]] = None, + event_to_attr: Optional[dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -128,7 +136,17 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - self._register_additional_events() + if event_names is not None: + if not isinstance(event_names, list): + raise ValueError("event_names must be a list or string or EventEnum.") + for name in event_names: + if isinstance(name, str): + self.register_events(name, event_to_attr=event_to_attr) + elif issubclass(name, EventEnum): + self.register_events(*name, event_to_attr=event_to_attr) + else: + raise ValueError("event_names must be a list or string or EventEnum.") + if post_transform is not None: self._register_post_transforms(post_transform) if key_metric is not None: @@ -136,14 +154,7 @@ def set_sampler_epoch(engine: Engine): if handlers is not None: self._register_handlers(handlers) - def _register_additional_events(self): - """ - Register more ignite Events to the engine. - - """ - pass - - def _register_post_transforms(self, posttrans): + def _register_post_transforms(self, posttrans: Callable): """ Register the post transforms to the engine, will execute them as a chain when iteration completed. @@ -151,11 +162,9 @@ def _register_post_transforms(self, posttrans): @self.on(Events.ITERATION_COMPLETED) def run_post_transform(engine: Engine) -> None: - if posttrans is None: - raise AssertionError engine.state.output = apply_transform(posttrans, engine.state.output) - def _register_metrics(self, k_metric, add_metrics): + def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ Register the key metric and additional metrics to the engine, supports ignite Metrics. @@ -180,7 +189,7 @@ def _compare_metrics(engine: Engine) -> None: engine.state.best_metric = current_val_metric engine.state.best_metric_epoch = engine.state.epoch - def _register_handlers(self, handlers): + def _register_handlers(self, handlers: Sequence): """ Register the handlers to the engine, supports ignite Handlers with `attach` API. diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 9cc977d876..28a2d4f941 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -12,7 +12,7 @@ import unittest import torch -from ignite.engine import Events +from ignite.engine import EventEnum, Events from monai.engines import EnsembleEvaluator @@ -44,11 +44,17 @@ def forward(self, x): net3 = TestNet(lambda x: x + 4) net4 = TestNet(lambda x: x + 5) + class CustomEvents(EventEnum): + FOO_EVENT = "foo_event" + BAR_EVENT = "bar_event" + val_engine = EnsembleEvaluator( device=device, val_data_loader=val_loader, networks=[net0, net1, net2, net3, net4], pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"], + event_names=["bwd_event", "opt_event", CustomEvents], + event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"}, ) @val_engine.on(Events.ITERATION_COMPLETED) @@ -57,6 +63,21 @@ def run_post_transform(engine): expected_value = engine.state.iteration + i torch.testing.assert_allclose(engine.state.output[f"pred{i}"], torch.tensor([[expected_value]])) + @val_engine.on(Events.EPOCH_COMPLETED) + def trigger_custom_event(): + val_engine.fire_event(CustomEvents.FOO_EVENT) + val_engine.fire_event(CustomEvents.BAR_EVENT) + val_engine.fire_event("bwd_event") + val_engine.fire_event("opt_event") + + @val_engine.on(CustomEvents.FOO_EVENT) + def do_foo_op(): + self.assertEqual(val_engine.state.foo, 0) + + @val_engine.on("opt_event") + def do_bar_op(): + self.assertEqual(val_engine.state.opt, 0) + val_engine.run() diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 9f63211a13..c2c5dcbfd6 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -67,7 +67,8 @@ def _train_func(engine, batch): self.assertGreater(gb_count[0], 0) if iter > 1: # Since we are collecting all objects from all generations manually at each call, - # starting from the second call, there shouldn't be any 1st and 2nd generation objects available to collect. + # starting from the second call, there shouldn't be any 1st and 2nd + # generation objects available to collect. self.assertEqual(gb_count[1], first_count) self.assertEqual(gb_count[2], first_count)