From 20010f023df3c8090d95e92e3bd469b1f340792a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Jan 2021 11:45:27 +0800 Subject: [PATCH 1/7] [DLMED] add more Events Signed-off-by: Nic Ma --- monai/engines/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 7a2dc40b8d..82237c134c 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -16,6 +16,12 @@ __all__ = ["CommonKeys", "GanKeys", "get_devices_spec", "default_prepare_batch", "default_make_latent"] +class IterationEvents: + """ + Addtional Events engine can register and trigger in the iteration process. + """ + + class CommonKeys: """ A set of common keys for dictionary based supervised training process. From 25623cb8dfe49bfa3bab0754b5418f69905e9f94 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Jan 2021 16:55:36 +0800 Subject: [PATCH 2/7] [DLMED] add 3 Events Signed-off-by: Nic Ma --- monai/engines/__init__.py | 2 +- monai/engines/trainer.py | 12 +++++- monai/engines/utils.py | 32 ++++++++++++++-- monai/engines/workflow.py | 77 ++++++++++++++++++++++----------------- 4 files changed, 83 insertions(+), 40 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 7b926c8469..cf1e544729 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,4 +12,4 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer -from .utils import CommonKeys, GanKeys, default_make_latent, default_prepare_batch, get_devices_spec +from .utils import IterationEvents, CommonKeys, GanKeys, default_make_latent, default_prepare_batch, get_devices_spec diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 5d4f82b0af..a825a884c9 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import GanKeys, default_make_latent, default_prepare_batch +from monai.engines.utils import IterationEvents, GanKeys, default_make_latent, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform @@ -121,6 +121,10 @@ def __init__( self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. 
@@ -153,14 +157,20 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): predictions = self.inferer(inputs, self.network, *args, **kwargs) + engine.fire_event(IterationEvents.PREDICT_COMPLETED) loss = self.loss_function(predictions, targets).mean() + engine.fire_event(IterationEvents.LOSS_COMPLETED) self.scaler.scale(loss).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: predictions = self.inferer(inputs, self.network, *args, **kwargs) + engine.fire_event(IterationEvents.PREDICT_COMPLETED) loss = self.loss_function(predictions, targets).mean() + engine.fire_event(IterationEvents.LOSS_COMPLETED) loss.backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions, Keys.LOSS: loss.item()} diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 82237c134c..1c773ca26d 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -9,18 +9,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Sequence, Tuple, Union - +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union import torch -__all__ = ["CommonKeys", "GanKeys", "get_devices_spec", "default_prepare_batch", "default_make_latent"] +from monai.utils import exact_version, optional_import + +if TYPE_CHECKING: + from ignite.engine import EventEnum +else: + EventEnum, _ = optional_import("ignite.engine", "0.4.2", exact_version, "EventEnum") + +__all__ = [ + "IterationEvents", + "CommonKeys", + "GanKeys", + "get_devices_spec", + "default_prepare_batch", + "default_make_latent", +] -class IterationEvents: +class IterationEvents(EventEnum): """ Addtional Events engine can register and trigger in the iteration process. + Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146 + These Events can be triggered during training iteration: + `PREDICT_COMPLETED` is the Event when `network(image, label)` completed. + `LOSS_COMPLETED` is the Event when `loss(pred, label)` completed. + `BACKWARD_COMPLETED` is the Event when `loss.backward()` completed. + """ + PREDICT_COMPLETED = "predict_completed" + LOSS_COMPLETED = "loss_completed" + BACKWARD_COMPLETED = 'backward_completed' + class CommonKeys: """ @@ -42,6 +65,7 @@ class CommonKeys: class GanKeys: """ A set of common keys for generative adversarial networks. 
+ """ REALS = "reals" diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1d8c74c4bb..8a7b204cd8 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -119,44 +119,53 @@ def set_sampler_epoch(engine: Engine): self.data_loader = data_loader self.non_blocking = non_blocking self.prepare_batch = prepare_batch + self.amp = amp + self._register_additional_events() if post_transform is not None: - - @self.on(Events.ITERATION_COMPLETED) - def run_post_transform(engine: Engine) -> None: - if post_transform is None: - raise AssertionError - engine.state.output = apply_transform(post_transform, engine.state.output) - + self._register_post_transforms(post_transform) if key_metric is not None: - - if not isinstance(key_metric, dict): - raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.") - self.state.key_metric_name = list(key_metric.keys())[0] - metrics = key_metric - if additional_metrics is not None and len(additional_metrics) > 0: - if not isinstance(additional_metrics, dict): - raise TypeError( - f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}." - ) - metrics.update(additional_metrics) - for name, metric in metrics.items(): - metric.attach(self, name) - - @self.on(Events.EPOCH_COMPLETED) - def _compare_metrics(engine: Engine) -> None: - if engine.state.key_metric_name is not None: - current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if current_val_metric > engine.state.best_metric: - self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric - engine.state.best_metric_epoch = engine.state.epoch - + self._register_metrics(key_metric, additional_metrics) if handlers is not None: - handlers_ = ensure_tuple(handlers) - for handler in handlers_: - handler.attach(self) - self.amp = amp + self._register_handlers(handlers) + + def _register_additional_events(self): + pass + + def _register_post_transforms(self, posttrans): + @self.on(Events.ITERATION_COMPLETED) + def run_post_transform(engine: Engine) -> None: + if posttrans is None: + raise AssertionError + engine.state.output = apply_transform(posttrans, engine.state.output) + + def _register_metrics(self, k_metric, add_metrics): + if not isinstance(k_metric, dict): + raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.") + self.state.key_metric_name = list(k_metric.keys())[0] + metrics = k_metric + if add_metrics is not None and len(add_metrics) > 0: + if not isinstance(add_metrics, dict): + raise TypeError( + f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}." 
+ ) + metrics.update(add_metrics) + for name, metric in metrics.items(): + metric.attach(self, name) + + def _register_handlers(self, handlers): + handlers_ = ensure_tuple(handlers) + for handler in handlers_: + handler.attach(self) + + @self.on(Events.EPOCH_COMPLETED) + def _compare_metrics(engine: Engine) -> None: + if engine.state.key_metric_name is not None: + current_val_metric = engine.state.metrics[engine.state.key_metric_name] + if current_val_metric > engine.state.best_metric: + self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") + engine.state.best_metric = current_val_metric + engine.state.best_metric_epoch = engine.state.epoch def run(self) -> None: """ From ea118df7b009e7a0fe41369399e3198bbf1561e9 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 21 Jan 2021 09:42:23 +0000 Subject: [PATCH 3/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/engines/__init__.py | 2 +- monai/engines/trainer.py | 2 +- monai/engines/utils.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index cf1e544729..8256680735 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,4 +12,4 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer -from .utils import IterationEvents, CommonKeys, GanKeys, default_make_latent, default_prepare_batch, get_devices_spec +from .utils import CommonKeys, GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a825a884c9..8a1e33b3ca 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import IterationEvents, GanKeys, default_make_latent, default_prepare_batch +from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 1c773ca26d..00f8cbb27b 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -10,6 +10,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + import torch from monai.utils import exact_version, optional_import @@ -42,7 +43,7 @@ class IterationEvents(EventEnum): PREDICT_COMPLETED = "predict_completed" LOSS_COMPLETED = "loss_completed" - BACKWARD_COMPLETED = 'backward_completed' + BACKWARD_COMPLETED = "backward_completed" class CommonKeys: From d73735ce300e3406f2305ac39af2b95c94916083 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Jan 2021 17:55:25 +0800 Subject: [PATCH 4/7] [DLMED] add tests Signed-off-by: Nic Ma --- tests/test_integration_workflows.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 124224ec3f..88d19eab72 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -25,7 +25,7 @@ import monai from monai.data import create_test_image_3d -from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines import SupervisedEvaluator, SupervisedTrainer, IterationEvents from monai.handlers import ( CheckpointLoader, CheckpointSaver, @@ -143,12 +143,29 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestAdditionalEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.PREDICT_COMPLETED, self._predict_completed) + engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) + engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed) + + def _predict_completed(self, engine): + pass + + def _loss_completed(self, engine): + pass + + def _backward_completed(self, engine): + pass + train_handlers = [ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), + _TestAdditionalEvents(), ] trainer = SupervisedTrainer( From d5a25cfdba4a1c0817efb7ea1ea329c7a9a67e75 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 21 Jan 2021 10:16:29 +0000 Subject: [PATCH 5/7] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_integration_workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 88d19eab72..45e024e401 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -25,7 +25,7 @@ import monai from monai.data import create_test_image_3d -from monai.engines import SupervisedEvaluator, SupervisedTrainer, IterationEvents +from monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer from monai.handlers import ( CheckpointLoader, CheckpointSaver, From 0fb4353af28f48c992d52b1b1a229d1b4097bf4e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 22 Jan 2021 11:11:39 +0800 Subject: [PATCH 6/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 30 +++++++++++++++++++-------- monai/engines/trainer.py | 25 ++++++++++++---------- monai/engines/utils.py | 5 +++-- monai/engines/workflow.py | 32 ++++++++++++++++++++++------- tests/test_integration_workflows.py 
| 21 +++++++++++++++---- 5 files changed, 80 insertions(+), 33 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e0ca59558e..34e7012263 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import default_prepare_batch +from monai.engines.utils import default_prepare_batch, IterationEvents from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode @@ -164,6 +164,10 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. @@ -190,15 +194,18 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict else: inputs, targets, args, kwargs = batch + # put iteration outputs into engine.state + engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation with eval_mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - predictions = self.inferer(inputs, self.network, *args, **kwargs) + output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) else: - predictions = self.inferer(inputs, self.network, *args, **kwargs) + output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.fire_event(IterationEvents.FORWARD_COMPLETED) - return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions} + return output class EnsembleEvaluator(Evaluator): @@ -266,6 +273,10 @@ def __init__( self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. 
@@ -295,14 +306,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict else: inputs, targets, args, kwargs = batch - # execute forward computation - predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets} + # put iteration outputs into engine.state + engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} for idx, network in enumerate(self.networks): with eval_mode(network): if self.amp: with torch.cuda.amp.autocast(): - predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) else: - predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + engine.fire_event(IterationEvents.FORWARD_COMPLETED) - return predictions + return output diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 8a1e33b3ca..efb2ab12fa 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -151,29 +151,32 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): kwargs: Dict = {} else: inputs, targets, args, kwargs = batch + # put iteration outputs into engine.state + engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + + def _compute_pred_loss(): + output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.fire_event(IterationEvents.FORWARD_COMPLETED) + output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean() + engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() self.optimizer.zero_grad() if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): - predictions = self.inferer(inputs, self.network, *args, **kwargs) - engine.fire_event(IterationEvents.PREDICT_COMPLETED) - loss = self.loss_function(predictions, targets).mean() - engine.fire_event(IterationEvents.LOSS_COMPLETED) - self.scaler.scale(loss).backward() + _compute_pred_loss() + self.scaler.scale(output[Keys.LOSS]).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: - predictions = self.inferer(inputs, self.network, *args, **kwargs) - engine.fire_event(IterationEvents.PREDICT_COMPLETED) - loss = self.loss_function(predictions, targets).mean() - engine.fire_event(IterationEvents.LOSS_COMPLETED) - loss.backward() + _compute_pred_loss() + output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() + engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED) - return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions, Keys.LOSS: loss.item()} + return output class GanTrainer(Trainer): diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 00f8cbb27b..f603338097 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -35,15 +35,16 @@ class IterationEvents(EventEnum): Addtional Events engine can register and trigger in the iteration process. Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146 These Events can be triggered during training iteration: - `PREDICT_COMPLETED` is the Event when `network(image, label)` completed. + `FORWARD_COMPLETED` is the Event when `network(image, label)` completed. `LOSS_COMPLETED` is the Event when `loss(pred, label)` completed. `BACKWARD_COMPLETED` is the Event when `loss.backward()` completed. 
""" - PREDICT_COMPLETED = "predict_completed" + FORWARD_COMPLETED = "forward_completed" LOSS_COMPLETED = "loss_completed" BACKWARD_COMPLETED = "backward_completed" + OPTIMIZER_COMPLETED = "optimizer_completed" class CommonKeys: diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8a7b204cd8..1201559973 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -130,9 +130,17 @@ def set_sampler_epoch(engine: Engine): self._register_handlers(handlers) def _register_additional_events(self): + """ + Register more ignite Events to the engine. + + """ pass def _register_post_transforms(self, posttrans): + """ + Register the post transforms to the engine, will execute them as a chain when iteration completed. + + """ @self.on(Events.ITERATION_COMPLETED) def run_post_transform(engine: Engine) -> None: if posttrans is None: @@ -140,24 +148,23 @@ def run_post_transform(engine: Engine) -> None: engine.state.output = apply_transform(posttrans, engine.state.output) def _register_metrics(self, k_metric, add_metrics): + """ + Register the key metric and additional metrics to the engine, supports ignite Metrics. + + """ if not isinstance(k_metric, dict): - raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.") + raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.") self.state.key_metric_name = list(k_metric.keys())[0] metrics = k_metric if add_metrics is not None and len(add_metrics) > 0: if not isinstance(add_metrics, dict): raise TypeError( - f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}." + f"additional metrics must be None or a dict but is {type(add_metrics).__name__}." ) metrics.update(add_metrics) for name, metric in metrics.items(): metric.attach(self, name) - def _register_handlers(self, handlers): - handlers_ = ensure_tuple(handlers) - for handler in handlers_: - handler.attach(self) - @self.on(Events.EPOCH_COMPLETED) def _compare_metrics(engine: Engine) -> None: if engine.state.key_metric_name is not None: @@ -167,6 +174,17 @@ def _compare_metrics(engine: Engine) -> None: engine.state.best_metric = current_val_metric engine.state.best_metric_epoch = engine.state.epoch + def _register_handlers(self, handlers): + """ + Register the handlers to the engine, supports ignite Handlers with `attach` API. + + """ + handlers_ = ensure_tuple(handlers) + for handler in handlers_: + handler.attach(self) + + + def run(self) -> None: """ Execute training, validation or evaluation based on Ignite Engine. 
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 45e024e401..aa4ccbb76d 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -113,6 +113,14 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestEvalIterEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) + + def _forward_completed(self, engine): + pass + val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None), @@ -120,6 +128,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True), + _TestEvalIterEvents(), ] evaluator = SupervisedEvaluator( @@ -144,13 +153,14 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): ] ) - class _TestAdditionalEvents: + class _TestTrainIterEvents: def attach(self, engine): - engine.add_event_handler(IterationEvents.PREDICT_COMPLETED, self._predict_completed) + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed) + engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed) - def _predict_completed(self, engine): + def _forward_completed(self, engine): pass def _loss_completed(self, engine): @@ -159,13 +169,16 @@ def _loss_completed(self, engine): def _backward_completed(self, engine): pass + def _optimizer_completed(self, engine): + pass + train_handlers = [ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), - _TestAdditionalEvents(), + _TestTrainIterEvents(), ] trainer = SupervisedTrainer( From 2be4ef1924cbd4f5b0368a512c90534b7c56be0d Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 22 Jan 2021 03:26:26 +0000 Subject: [PATCH 7/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/engines/evaluator.py | 2 +- monai/engines/workflow.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 34e7012263..0b7167fb3a 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import default_prepare_batch, IterationEvents +from monai.engines.utils import IterationEvents, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1201559973..67fdacad4a 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -141,6 +141,7 @@ def 
_register_post_transforms(self, posttrans): Register the post transforms to the engine, will execute them as a chain when iteration completed. """ + @self.on(Events.ITERATION_COMPLETED) def run_post_transform(engine: Engine) -> None: if posttrans is None: @@ -158,9 +159,7 @@ def _register_metrics(self, k_metric, add_metrics): metrics = k_metric if add_metrics is not None and len(add_metrics) > 0: if not isinstance(add_metrics, dict): - raise TypeError( - f"additional metrics must be None or a dict but is {type(add_metrics).__name__}." - ) + raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.") metrics.update(add_metrics) for name, metric in metrics.items(): metric.attach(self, name) @@ -183,8 +182,6 @@ def _register_handlers(self, handlers): for handler in handlers_: handler.attach(self) - - def run(self) -> None: """ Execute training, validation or evaluation based on Ignite Engine.
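
The pattern exercised by the _TestTrainIterEvents helper in tests/test_integration_workflows.py is the one intended for downstream code: a handler object exposes attach(engine), subscribes to the new IterationEvents with engine.add_event_handler, and is passed to the engine through train_handlers/val_handlers. The sketch below is a minimal illustration of that pattern against the patched API; the GradNormLogger name and the gradient-norm logging it performs are illustrative assumptions, not something these patches add.

    # Minimal sketch of a user handler built on the IterationEvents added above.
    # Assumes a MONAI build that includes these patches; GradNormLogger is hypothetical.
    import torch

    from monai.engines import IterationEvents


    class GradNormLogger:
        """Log the total gradient norm once loss.backward() has completed."""

        def __init__(self, network: torch.nn.Module):
            self.network = network

        def attach(self, engine) -> None:
            # Same registration style as _TestTrainIterEvents in the integration test:
            # the engine has already registered IterationEvents via
            # _register_additional_events(), so add_event_handler accepts them
            # like any built-in ignite event.
            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
            engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed)

        def _backward_completed(self, engine) -> None:
            # Note: under AMP the gradients are still loss-scaled at this point,
            # because the event fires before scaler.step()/scaler.update().
            grads = [p.grad.detach() for p in self.network.parameters() if p.grad is not None]
            if grads:
                total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads)).item()
                engine.logger.info("iteration %d: grad norm %.4f", engine.state.iteration, total_norm)

        def _optimizer_completed(self, engine) -> None:
            engine.logger.info("iteration %d: optimizer step done", engine.state.iteration)


    # The handler is then passed alongside the existing handlers, e.g.:
    #     train_handlers = [..., GradNormLogger(network=net)]
    #     trainer = SupervisedTrainer(..., train_handlers=train_handlers)

Engines other than SupervisedTrainer and SupervisedEvaluator can expose the same hooks by overriding Workflow._register_additional_events and calling self.register_events(*IterationEvents), which is exactly what the trainer and evaluators do in these patches.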