2 changes: 1 addition & 1 deletion monai/engines/__init__.py
@@ -12,4 +12,4 @@
from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
from .trainer import GanTrainer, SupervisedTrainer, Trainer
from .utils import CommonKeys, GanKeys, default_make_latent, default_prepare_batch, get_devices_spec
from .utils import CommonKeys, GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec
30 changes: 21 additions & 9 deletions monai/engines/evaluator.py
@@ -15,7 +15,7 @@
from torch.utils.data import DataLoader

from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import default_prepare_batch
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode
@@ -164,6 +164,10 @@ def __init__(
self.network = network
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
@@ -190,15 +194,18 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
else:
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
# execute forward computation
with eval_mode(self.network):
if self.amp:
with torch.cuda.amp.autocast():
predictions = self.inferer(inputs, self.network, *args, **kwargs)
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
else:
predictions = self.inferer(inputs, self.network, *args, **kwargs)
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)

return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions}
return output


class EnsembleEvaluator(Evaluator):
@@ -266,6 +273,10 @@ def __init__(
self.pred_keys = ensure_tuple(pred_keys)
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Callback function for the Ensemble Evaluation processing logic of 1 iteration in Ignite Engine.
@@ -295,14 +306,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
else:
inputs, targets, args, kwargs = batch

# execute forward computation
predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets}
# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
for idx, network in enumerate(self.networks):
with eval_mode(network):
if self.amp:
with torch.cuda.amp.autocast():
predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
else:
predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
engine.fire_event(IterationEvents.FORWARD_COMPLETED)

return predictions
return output
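
Because both evaluators now place their outputs into `engine.state.output` before firing `FORWARD_COMPLETED`, a handler attached to that event can inspect or post-process the predictions mid-iteration. A minimal sketch of the idea, using the standard Ignite `add_event_handler` API; the helper name and the sigmoid/threshold step are illustrative only, not part of this change:

```python
import torch

from monai.engines import IterationEvents, SupervisedEvaluator


def add_binarize_handler(evaluator: SupervisedEvaluator) -> None:
    """Attach a hypothetical post-forward handler; `evaluator` is assumed built elsewhere."""

    def _binarize(engine) -> None:
        # FORWARD_COMPLETED fires right after the inferer call, so "pred" is already in state.output
        engine.state.output["pred"] = (torch.sigmoid(engine.state.output["pred"]) > 0.5).float()

    evaluator.add_event_handler(IterationEvents.FORWARD_COMPLETED, _binarize)
```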
29 changes: 21 additions & 8 deletions monai/engines/trainer.py
@@ -16,7 +16,7 @@
from torch.utils.data import DataLoader

from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import GanKeys, default_make_latent, default_prepare_batch
from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
@@ -121,6 +121,10 @@ def __init__(
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer

def _register_additional_events(self):
super()._register_additional_events()
self.register_events(*IterationEvents)

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
@@ -147,23 +151,32 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
kwargs: Dict = {}
else:
inputs, targets, args, kwargs = batch
# put iteration outputs into engine.state
engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

def _compute_pred_loss():
output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean()
engine.fire_event(IterationEvents.LOSS_COMPLETED)

self.network.train()
self.optimizer.zero_grad()
if self.amp and self.scaler is not None:
with torch.cuda.amp.autocast():
predictions = self.inferer(inputs, self.network, *args, **kwargs)
loss = self.loss_function(predictions, targets).mean()
self.scaler.scale(loss).backward()
_compute_pred_loss()
self.scaler.scale(output[Keys.LOSS]).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
predictions = self.inferer(inputs, self.network, *args, **kwargs)
loss = self.loss_function(predictions, targets).mean()
loss.backward()
_compute_pred_loss()
output[Keys.LOSS].backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.optimizer.step()
engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)

return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions, Keys.LOSS: loss.item()}
return output


class GanTrainer(Trainer):
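
With `BACKWARD_COMPLETED` fired between `loss.backward()` and the optimizer step, a handler can act on the freshly computed gradients before they are applied. A minimal sketch of gradient clipping on that event, assuming a non-AMP run (with AMP enabled the gradients would still be scaled at this point); the helper name and `max_norm` value are illustrative:

```python
import torch

from monai.engines import IterationEvents, SupervisedTrainer


def add_grad_clip_handler(trainer: SupervisedTrainer, max_norm: float = 1.0) -> None:
    """Clip gradients right after backward(), before optimizer.step()."""

    def _clip(engine) -> None:
        # the engine passed to the handler is the trainer itself, so its network is available here
        torch.nn.utils.clip_grad_norm_(engine.network.parameters(), max_norm=max_norm)

    trainer.add_event_handler(IterationEvents.BACKWARD_COMPLETED, _clip)
```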
36 changes: 34 additions & 2 deletions monai/engines/utils.py
@@ -9,11 +9,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

import torch

__all__ = ["CommonKeys", "GanKeys", "get_devices_spec", "default_prepare_batch", "default_make_latent"]
from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
EventEnum, _ = optional_import("ignite.engine", "0.4.2", exact_version, "EventEnum")

__all__ = [
"IterationEvents",
"CommonKeys",
"GanKeys",
"get_devices_spec",
"default_prepare_batch",
"default_make_latent",
]


class IterationEvents(EventEnum):
"""
Additional Events that the engine can register and trigger during the iteration process.
Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146
These Events can be triggered during a training iteration:
`FORWARD_COMPLETED` is the Event triggered after `network(image)` completes.
`LOSS_COMPLETED` is the Event triggered after `loss(pred, label)` completes.
`BACKWARD_COMPLETED` is the Event triggered after `loss.backward()` completes.
`OPTIMIZER_COMPLETED` is the Event triggered after `optimizer.step()` completes.

"""

FORWARD_COMPLETED = "forward_completed"
LOSS_COMPLETED = "loss_completed"
BACKWARD_COMPLETED = "backward_completed"
OPTIMIZER_COMPLETED = "optimizer_completed"


class CommonKeys:
@@ -36,6 +67,7 @@ class CommonKeys:
class GanKeys:
"""
A set of common keys for generative adversarial networks.

"""

REALS = "reals"
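
Handlers that follow MONAI's `attach` convention (as the integration test below does) can subscribe to these events and be passed through a workflow's `handlers` argument, which calls `attach` on each of them. A minimal sketch; the class name and the counter it keeps are illustrative only:

```python
from ignite.engine import Engine

from monai.engines import IterationEvents


class ForwardCounter:
    """Hypothetical handler: counts completed forward passes."""

    def attach(self, engine: Engine) -> None:
        self.count = 0
        engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._on_forward)

    def _on_forward(self, engine: Engine) -> None:
        self.count += 1
```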
88 changes: 56 additions & 32 deletions monai/engines/workflow.py
@@ -119,44 +119,68 @@ def set_sampler_epoch(engine: Engine):
self.data_loader = data_loader
self.non_blocking = non_blocking
self.prepare_batch = prepare_batch
self.amp = amp

self._register_additional_events()
if post_transform is not None:
self._register_post_transforms(post_transform)
if key_metric is not None:
self._register_metrics(key_metric, additional_metrics)
if handlers is not None:
self._register_handlers(handlers)

@self.on(Events.ITERATION_COMPLETED)
def run_post_transform(engine: Engine) -> None:
if post_transform is None:
raise AssertionError
engine.state.output = apply_transform(post_transform, engine.state.output)
def _register_additional_events(self):
"""
Register additional ignite Events to the engine.

if key_metric is not None:
"""
pass

if not isinstance(key_metric, dict):
raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.")
self.state.key_metric_name = list(key_metric.keys())[0]
metrics = key_metric
if additional_metrics is not None and len(additional_metrics) > 0:
if not isinstance(additional_metrics, dict):
raise TypeError(
f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}."
)
metrics.update(additional_metrics)
for name, metric in metrics.items():
metric.attach(self, name)

@self.on(Events.EPOCH_COMPLETED)
def _compare_metrics(engine: Engine) -> None:
if engine.state.key_metric_name is not None:
current_val_metric = engine.state.metrics[engine.state.key_metric_name]
if current_val_metric > engine.state.best_metric:
self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch
def _register_post_transforms(self, posttrans):
"""
Register the post transforms to the engine; they are executed as a chain when an iteration completes.

if handlers is not None:
handlers_ = ensure_tuple(handlers)
for handler in handlers_:
handler.attach(self)
self.amp = amp
"""

@self.on(Events.ITERATION_COMPLETED)
def run_post_transform(engine: Engine) -> None:
if posttrans is None:
raise AssertionError
engine.state.output = apply_transform(posttrans, engine.state.output)

def _register_metrics(self, k_metric, add_metrics):
"""
Register the key metric and additional metrics to the engine; supports ignite Metrics.

"""
if not isinstance(k_metric, dict):
raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.")
self.state.key_metric_name = list(k_metric.keys())[0]
metrics = k_metric
if add_metrics is not None and len(add_metrics) > 0:
if not isinstance(add_metrics, dict):
raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.")
metrics.update(add_metrics)
for name, metric in metrics.items():
metric.attach(self, name)

@self.on(Events.EPOCH_COMPLETED)
def _compare_metrics(engine: Engine) -> None:
if engine.state.key_metric_name is not None:
current_val_metric = engine.state.metrics[engine.state.key_metric_name]
if current_val_metric > engine.state.best_metric:
self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch

def _register_handlers(self, handlers):
"""
Register the handlers to the engine; supports ignite Handlers that expose an `attach` API.

"""
handlers_ = ensure_tuple(handlers)
for handler in handlers_:
handler.attach(self)

def run(self) -> None:
"""
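
Moving event registration into `_register_additional_events` gives subclasses a single override point, which is exactly how the evaluators and trainer above register `IterationEvents`. A sketch of a hypothetical subclass registering its own events, assuming ignite 0.4.2 as pinned in this change; the enum, event name, and class name are made up for illustration:

```python
from ignite.engine import EventEnum

from monai.engines import SupervisedTrainer


class CustomEvents(EventEnum):
    # hypothetical extra event a project-specific trainer might fire
    PRE_FORWARD = "pre_forward"


class MyTrainer(SupervisedTrainer):
    def _register_additional_events(self):
        # keep IterationEvents from the parent class, then add the custom ones
        super()._register_additional_events()
        self.register_events(*CustomEvents)
```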
32 changes: 31 additions & 1 deletion tests/test_integration_workflows.py
@@ -25,7 +25,7 @@

import monai
from monai.data import create_test_image_3d
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointLoader,
CheckpointSaver,
@@ -113,13 +113,22 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
]
)

class _TestEvalIterEvents:
def attach(self, engine):
engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

def _forward_completed(self, engine):
pass

val_handlers = [
StatsHandler(output_transform=lambda x: None),
TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None),
TensorBoardImageHandler(
log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"]
),
CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
_TestEvalIterEvents(),
]

evaluator = SupervisedEvaluator(
@@ -143,12 +152,33 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
]
)

class _TestTrainIterEvents:
def attach(self, engine):
engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed)

def _forward_completed(self, engine):
pass

def _loss_completed(self, engine):
pass

def _backward_completed(self, engine):
pass

def _optimizer_completed(self, engine):
pass

train_handlers = [
LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
_TestTrainIterEvents(),
]

trainer = SupervisedTrainer(