diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py
index 7b926c8469..8256680735 100644
--- a/monai/engines/__init__.py
+++ b/monai/engines/__init__.py
@@ -12,4 +12,4 @@
 from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
 from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
 from .trainer import GanTrainer, SupervisedTrainer, Trainer
-from .utils import CommonKeys, GanKeys, default_make_latent, default_prepare_batch, get_devices_spec
+from .utils import CommonKeys, GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec
diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index e0ca59558e..0b7167fb3a 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -15,7 +15,7 @@
 from torch.utils.data import DataLoader
 
 from monai.engines.utils import CommonKeys as Keys
-from monai.engines.utils import default_prepare_batch
+from monai.engines.utils import IterationEvents, default_prepare_batch
 from monai.engines.workflow import Workflow
 from monai.inferers import Inferer, SimpleInferer
 from monai.networks.utils import eval_mode
@@ -164,6 +164,10 @@ def __init__(
         self.network = network
         self.inferer = SimpleInferer() if inferer is None else inferer
 
+    def _register_additional_events(self):
+        super()._register_additional_events()
+        self.register_events(*IterationEvents)
+
     def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
@@ -190,15 +194,18 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
         else:
             inputs, targets, args, kwargs = batch
 
+        # put iteration outputs into engine.state
+        engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
         # execute forward computation
         with eval_mode(self.network):
             if self.amp:
                 with torch.cuda.amp.autocast():
-                    predictions = self.inferer(inputs, self.network, *args, **kwargs)
+                    output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
             else:
-                predictions = self.inferer(inputs, self.network, *args, **kwargs)
+                output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
+        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
 
-        return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions}
+        return output
 
 
 class EnsembleEvaluator(Evaluator):
@@ -266,6 +273,10 @@ def __init__(
         self.pred_keys = ensure_tuple(pred_keys)
         self.inferer = SimpleInferer() if inferer is None else inferer
 
+    def _register_additional_events(self):
+        super()._register_additional_events()
+        self.register_events(*IterationEvents)
+
     def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
@@ -295,14 +306,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
         else:
             inputs, targets, args, kwargs = batch
 
-        # execute forward computation
-        predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets}
+        # put iteration outputs into engine.state
+        engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
         for idx, network in enumerate(self.networks):
             with eval_mode(network):
                 if self.amp:
                     with torch.cuda.amp.autocast():
-                        predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
+                        output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
                 else:
-                    predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
+                    output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
+        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
 
-        return predictions
+        return output
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index 5d4f82b0af..efb2ab12fa 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -16,7 +16,7 @@
 from torch.utils.data import DataLoader
 
 from monai.engines.utils import CommonKeys as Keys
-from monai.engines.utils import GanKeys, default_make_latent, default_prepare_batch
+from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch
 from monai.engines.workflow import Workflow
 from monai.inferers import Inferer, SimpleInferer
 from monai.transforms import Transform
@@ -121,6 +121,10 @@ def __init__(
         self.loss_function = loss_function
         self.inferer = SimpleInferer() if inferer is None else inferer
 
+    def _register_additional_events(self):
+        super()._register_additional_events()
+        self.register_events(*IterationEvents)
+
     def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
         """
         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
@@ -147,23 +151,32 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
             kwargs: Dict = {}
         else:
             inputs, targets, args, kwargs = batch
+        # put iteration outputs into engine.state
+        engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
+
+        def _compute_pred_loss():
+            output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
+            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
+            output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean()
+            engine.fire_event(IterationEvents.LOSS_COMPLETED)
 
         self.network.train()
         self.optimizer.zero_grad()
         if self.amp and self.scaler is not None:
             with torch.cuda.amp.autocast():
-                predictions = self.inferer(inputs, self.network, *args, **kwargs)
-                loss = self.loss_function(predictions, targets).mean()
-            self.scaler.scale(loss).backward()
+                _compute_pred_loss()
+            self.scaler.scale(output[Keys.LOSS]).backward()
+            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             self.scaler.step(self.optimizer)
             self.scaler.update()
         else:
-            predictions = self.inferer(inputs, self.network, *args, **kwargs)
-            loss = self.loss_function(predictions, targets).mean()
-            loss.backward()
+            _compute_pred_loss()
+            output[Keys.LOSS].backward()
+            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             self.optimizer.step()
+        engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)
 
-        return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions, Keys.LOSS: loss.item()}
+        return output
 
 
 class GanTrainer(Trainer):
diff --git a/monai/engines/utils.py b/monai/engines/utils.py
index 7a2dc40b8d..f603338097 100644
--- a/monai/engines/utils.py
+++ b/monai/engines/utils.py
@@ -9,11 +9,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 
-__all__ = ["CommonKeys", "GanKeys", "get_devices_spec", "default_prepare_batch", "default_make_latent"]
+from monai.utils import exact_version, optional_import
+
+if TYPE_CHECKING:
+    from ignite.engine import EventEnum
+else:
+    EventEnum, _ = optional_import("ignite.engine", "0.4.2", exact_version, "EventEnum")
+
+__all__ = [
+    "IterationEvents",
+    "CommonKeys",
+    "GanKeys",
+    "get_devices_spec",
+    "default_prepare_batch",
+    "default_make_latent",
+]
+
+
+class IterationEvents(EventEnum):
+    """
+    Additional Events that the engine can register and trigger in the iteration process.
+    Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146
+    These Events can be triggered during a training iteration:
+    `FORWARD_COMPLETED` is the Event triggered when `network(image, label)` is completed.
+    `LOSS_COMPLETED` is the Event triggered when `loss(pred, label)` is completed.
+    `BACKWARD_COMPLETED` is the Event triggered when `loss.backward()` is completed.
+    `OPTIMIZER_COMPLETED` is the Event triggered when `optimizer.step()` is completed.
+
+    """
+
+    FORWARD_COMPLETED = "forward_completed"
+    LOSS_COMPLETED = "loss_completed"
+    BACKWARD_COMPLETED = "backward_completed"
+    OPTIMIZER_COMPLETED = "optimizer_completed"
 
 
 class CommonKeys:
@@ -36,6 +68,7 @@ class CommonKeys:
 class GanKeys:
     """
     A set of common keys for generative adversarial networks.
+ """ REALS = "reals" diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1d8c74c4bb..67fdacad4a 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -119,44 +119,68 @@ def set_sampler_epoch(engine: Engine): self.data_loader = data_loader self.non_blocking = non_blocking self.prepare_batch = prepare_batch + self.amp = amp + self._register_additional_events() if post_transform is not None: + self._register_post_transforms(post_transform) + if key_metric is not None: + self._register_metrics(key_metric, additional_metrics) + if handlers is not None: + self._register_handlers(handlers) - @self.on(Events.ITERATION_COMPLETED) - def run_post_transform(engine: Engine) -> None: - if post_transform is None: - raise AssertionError - engine.state.output = apply_transform(post_transform, engine.state.output) + def _register_additional_events(self): + """ + Register more ignite Events to the engine. - if key_metric is not None: + """ + pass - if not isinstance(key_metric, dict): - raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.") - self.state.key_metric_name = list(key_metric.keys())[0] - metrics = key_metric - if additional_metrics is not None and len(additional_metrics) > 0: - if not isinstance(additional_metrics, dict): - raise TypeError( - f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}." - ) - metrics.update(additional_metrics) - for name, metric in metrics.items(): - metric.attach(self, name) - - @self.on(Events.EPOCH_COMPLETED) - def _compare_metrics(engine: Engine) -> None: - if engine.state.key_metric_name is not None: - current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if current_val_metric > engine.state.best_metric: - self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric - engine.state.best_metric_epoch = engine.state.epoch + def _register_post_transforms(self, posttrans): + """ + Register the post transforms to the engine, will execute them as a chain when iteration completed. - if handlers is not None: - handlers_ = ensure_tuple(handlers) - for handler in handlers_: - handler.attach(self) - self.amp = amp + """ + + @self.on(Events.ITERATION_COMPLETED) + def run_post_transform(engine: Engine) -> None: + if posttrans is None: + raise AssertionError + engine.state.output = apply_transform(posttrans, engine.state.output) + + def _register_metrics(self, k_metric, add_metrics): + """ + Register the key metric and additional metrics to the engine, supports ignite Metrics. 
+ + """ + if not isinstance(k_metric, dict): + raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.") + self.state.key_metric_name = list(k_metric.keys())[0] + metrics = k_metric + if add_metrics is not None and len(add_metrics) > 0: + if not isinstance(add_metrics, dict): + raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.") + metrics.update(add_metrics) + for name, metric in metrics.items(): + metric.attach(self, name) + + @self.on(Events.EPOCH_COMPLETED) + def _compare_metrics(engine: Engine) -> None: + if engine.state.key_metric_name is not None: + current_val_metric = engine.state.metrics[engine.state.key_metric_name] + if current_val_metric > engine.state.best_metric: + self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") + engine.state.best_metric = current_val_metric + engine.state.best_metric_epoch = engine.state.epoch + + def _register_handlers(self, handlers): + """ + Register the handlers to the engine, supports ignite Handlers with `attach` API. + + """ + handlers_ = ensure_tuple(handlers) + for handler in handlers_: + handler.attach(self) def run(self) -> None: """ diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 124224ec3f..aa4ccbb76d 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -25,7 +25,7 @@ import monai from monai.data import create_test_image_3d -from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer from monai.handlers import ( CheckpointLoader, CheckpointSaver, @@ -113,6 +113,14 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestEvalIterEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) + + def _forward_completed(self, engine): + pass + val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None), @@ -120,6 +128,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True), + _TestEvalIterEvents(), ] evaluator = SupervisedEvaluator( @@ -143,12 +152,33 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestTrainIterEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) + engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) + engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed) + engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed) + + def _forward_completed(self, engine): + pass + + def _loss_completed(self, engine): + pass + + def _backward_completed(self, engine): + pass + + def _optimizer_completed(self, engine): + pass + train_handlers = [ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), StatsHandler(tag_name="train_loss", output_transform=lambda x: 
x["loss"]), TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), + _TestTrainIterEvents(), ] trainer = SupervisedTrainer(