39 changes: 29 additions & 10 deletions monai/engines/evaluator.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader
@@ -23,11 +23,12 @@
from monai.utils.enums import CommonKeys as Keys

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.engine import Engine, EventEnum
from ignite.metrics import Metric
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")
EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum")

__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"]

@@ -56,6 +57,10 @@ class Evaluator(Workflow):
amp: whether to enable auto-mixed-precision evaluation, default is False.
mode: model forward mode during evaluation, should be 'eval' or 'train',
which maps to `model.eval()` or `model.train()`, default to 'eval'.
event_names: additional custom ignite events that will be registered to the engine.
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160

"""

@@ -73,6 +78,8 @@ def __init__(
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
) -> None:
super().__init__(
device=device,
@@ -87,6 +94,8 @@
additional_metrics=additional_metrics,
handlers=val_handlers,
amp=amp,
event_names=event_names,
event_to_attr=event_to_attr,
)
mode = ForwardMode(mode)
if mode == ForwardMode.EVAL:
@@ -140,6 +149,10 @@ class SupervisedEvaluator(Evaluator):
amp: whether to enable auto-mixed-precision evaluation, default is False.
mode: model forward mode during evaluation, should be 'eval' or 'train',
which maps to `model.eval()` or `model.train()`, default to 'eval'.
event_names: additional custom ignite events that will be registered to the engine.
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160

"""

@@ -159,6 +172,8 @@ def __init__(
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
) -> None:
super().__init__(
device=device,
@@ -173,15 +188,14 @@ def __init__(
val_handlers=val_handlers,
amp=amp,
mode=mode,
# add the iteration events
event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents],
event_to_attr=event_to_attr,
)

self.network = network
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
@@ -251,6 +265,10 @@ class EnsembleEvaluator(Evaluator):
amp: whether to enable auto-mixed-precision evaluation, default is False.
mode: model forward mode during evaluation, should be 'eval' or 'train',
which maps to `model.eval()` or `model.train()`, default to 'eval'.
event_names: additional custom ignite events that will be registered to the engine.
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160

"""

@@ -271,6 +289,8 @@ def __init__(
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
) -> None:
super().__init__(
device=device,
@@ -285,16 +305,15 @@
val_handlers=val_handlers,
amp=amp,
mode=mode,
# add the iteration events
event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents],
event_to_attr=event_to_attr,
)

self.networks = ensure_tuple(networks)
self.pred_keys = ensure_tuple(pred_keys)
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
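The new parameters flow straight through to ignite's `Engine.register_events`. As a rough usage sketch (the identity network, toy dict dataset, and event names below are illustrative assumptions, not part of this diff):

import torch
from ignite.engine import EventEnum, Events
from torch.utils.data import DataLoader

from monai.engines import SupervisedEvaluator


class CustomEvents(EventEnum):
    FOO_EVENT = "foo_event"


# the default prepare_batch expects dict batches with "image"/"label" keys
data = [{"image": torch.zeros(1, 1, 2, 2), "label": torch.zeros(1, 1, 2, 2)}]
evaluator = SupervisedEvaluator(
    device=torch.device("cpu"),
    val_data_loader=DataLoader(data),
    network=torch.nn.Identity(),
    event_names=[CustomEvents, "plain_string_event"],
    event_to_attr={CustomEvents.FOO_EVENT: "foo"},  # adds engine.state.foo
)


@evaluator.on(Events.ITERATION_COMPLETED)
def fire_custom_event(engine):
    # custom events are never fired by the engine itself; handlers fire them
    engine.fire_event(CustomEvents.FOO_EVENT)


@evaluator.on(CustomEvents.FOO_EVENT)
def on_foo(engine):
    print("foo fired, engine.state.foo =", engine.state.foo)


evaluator.run()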
19 changes: 12 additions & 7 deletions monai/engines/trainer.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer
@@ -23,11 +23,12 @@
from monai.utils.enums import CommonKeys as Keys

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.engine import Engine, EventEnum
from ignite.metrics import Metric
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")
EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum")

__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]

@@ -78,6 +79,10 @@ class SupervisedTrainer(Trainer):
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision training, default is False.
event_names: additional custom ignite events that will be registered to the engine.
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160

"""

@@ -99,8 +104,9 @@ def __init__(
additional_metrics: Optional[Dict[str, Metric]] = None,
train_handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
) -> None:
# set up Ignite engine and environments
super().__init__(
device=device,
max_epochs=max_epochs,
@@ -114,17 +120,16 @@
additional_metrics=additional_metrics,
handlers=train_handlers,
amp=amp,
# add the iteration events
event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents],
event_to_attr=event_to_attr,
)

self.network = network
self.optimizer = optimizer
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
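Note that `IterationEvents` is appended unconditionally here, so user-supplied `event_names` extend rather than replace the built-in iteration events. A minimal sketch under that assumption (the toy network and data are made up; it also assumes `_iteration` fires `IterationEvents.FORWARD_COMPLETED` after the forward pass, which is what registering these events enables):

import torch
from torch.utils.data import DataLoader

from monai.engines import SupervisedTrainer
from monai.engines.utils import IterationEvents

net = torch.nn.Linear(2, 2)
data = [{"image": torch.zeros(1, 2), "label": torch.zeros(1, 2)}]
trainer = SupervisedTrainer(
    device=torch.device("cpu"),
    max_epochs=1,
    train_data_loader=DataLoader(data),
    network=net,
    optimizer=torch.optim.SGD(net.parameters(), lr=0.1),
    loss_function=torch.nn.MSELoss(),
    event_names=["my_extra_event"],  # appended to IterationEvents internally
)


@trainer.on(IterationEvents.FORWARD_COMPLETED)
def after_forward(engine):
    # runs right after the network forward pass inside _iteration
    print("forward completed at iteration", engine.state.iteration)


trainer.run()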
39 changes: 24 additions & 15 deletions monai/engines/workflow.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch
import torch.distributed as dist
@@ -23,12 +23,14 @@
IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.engine import Engine, EventEnum
from ignite.metrics import Metric
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")
EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum")


class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import
@@ -60,6 +62,10 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import
handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision training or inference, default is False.
event_names: additional custom ignite events that will be registered to the engine.
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160

Raises:
TypeError: When ``device`` is not a ``torch.Device``.
@@ -83,6 +89,8 @@ def __init__(
additional_metrics: Optional[Dict[str, Metric]] = None,
handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
@@ -128,34 +136,35 @@ def set_sampler_epoch(engine: Engine):
self.prepare_batch = prepare_batch
self.amp = amp

self._register_additional_events()
if event_names is not None:
if not isinstance(event_names, list):
raise ValueError("event_names must be a list or string or EventEnum.")
for name in event_names:
if isinstance(name, str):
self.register_events(name, event_to_attr=event_to_attr)
elif issubclass(name, EventEnum):
self.register_events(*name, event_to_attr=event_to_attr)
else:
raise ValueError("event_names must be a list or string or EventEnum.")

if post_transform is not None:
self._register_post_transforms(post_transform)
if key_metric is not None:
self._register_metrics(key_metric, additional_metrics)
if handlers is not None:
self._register_handlers(handlers)

def _register_additional_events(self):
"""
Register more ignite Events to the engine.

"""
pass

def _register_post_transforms(self, posttrans):
def _register_post_transforms(self, posttrans: Callable):
"""
Register the post transforms to the engine, will execute them as a chain when iteration completed.

"""

@self.on(Events.ITERATION_COMPLETED)
def run_post_transform(engine: Engine) -> None:
if posttrans is None:
raise AssertionError
engine.state.output = apply_transform(posttrans, engine.state.output)

def _register_metrics(self, k_metric, add_metrics):
def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):
"""
Register the key metric and additional metrics to the engine, supports ignite Metrics.

@@ -180,7 +189,7 @@ def _compare_metrics(engine: Engine) -> None:
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch

def _register_handlers(self, handlers):
def _register_handlers(self, handlers: Sequence):
"""
Register the handlers to the engine, supports ignite Handlers with `attach` API.

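The `str` vs `EventEnum` branch mirrors the two calling conventions of ignite's `register_events`: a plain string registers a single event, while an `EventEnum` subclass is unpacked with `*` so every member is registered at once. A standalone sketch of that distinction against a bare ignite `Engine` (the event names are invented for illustration):

from ignite.engine import Engine, EventEnum


class AlphaEvents(EventEnum):
    ALPHA_STARTED = "alpha_started"
    ALPHA_COMPLETED = "alpha_completed"


engine = Engine(lambda e, batch: batch)

# a string registers exactly one event name
engine.register_events("plain_event")
# an EventEnum subclass is unpacked, registering ALPHA_STARTED and ALPHA_COMPLETED;
# event_to_attr creates engine.state.alpha for the mapped event
engine.register_events(*AlphaEvents, event_to_attr={AlphaEvents.ALPHA_STARTED: "alpha"})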
23 changes: 22 additions & 1 deletion tests/test_ensemble_evaluator.py
@@ -12,7 +12,7 @@
import unittest

import torch
from ignite.engine import Events
from ignite.engine import EventEnum, Events

from monai.engines import EnsembleEvaluator

@@ -44,11 +44,17 @@ def forward(self, x):
net3 = TestNet(lambda x: x + 4)
net4 = TestNet(lambda x: x + 5)

class CustomEvents(EventEnum):
FOO_EVENT = "foo_event"
BAR_EVENT = "bar_event"

val_engine = EnsembleEvaluator(
device=device,
val_data_loader=val_loader,
networks=[net0, net1, net2, net3, net4],
pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
event_names=["bwd_event", "opt_event", CustomEvents],
event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"},
)

@val_engine.on(Events.ITERATION_COMPLETED)
@@ -57,6 +63,21 @@ def run_post_transform(engine):
expected_value = engine.state.iteration + i
torch.testing.assert_allclose(engine.state.output[f"pred{i}"], torch.tensor([[expected_value]]))

@val_engine.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
val_engine.fire_event(CustomEvents.FOO_EVENT)
val_engine.fire_event(CustomEvents.BAR_EVENT)
val_engine.fire_event("bwd_event")
val_engine.fire_event("opt_event")

@val_engine.on(CustomEvents.FOO_EVENT)
def do_foo_op():
self.assertEqual(val_engine.state.foo, 0)

@val_engine.on("opt_event")
def do_bar_op():
self.assertEqual(val_engine.state.opt, 0)

val_engine.run()


3 changes: 2 additions & 1 deletion tests/test_handler_garbage_collector.py
@@ -67,7 +67,8 @@ def _train_func(engine, batch):
self.assertGreater(gb_count[0], 0)
if iter > 1:
# Since we are collecting all objects from all generations manually at each call,
# starting from the second call, there shouldn't be any 1st and 2nd generation objects available to collect.
# starting from the second call, there shouldn't be any 1st and 2nd
# generation objects available to collect.
self.assertEqual(gb_count[1], first_count)
self.assertEqual(gb_count[2], first_count)
