Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,8 @@ SmartCache handler
------------------
.. autoclass:: SmartCacheHandler
:members:

EarlyStop handler
-----------------
.. autoclass:: EarlyStopHandler
:members:
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .checkpoint_saver import CheckpointSaver
from .classification_saver import ClassificationSaver
from .confusion_matrix import ConfusionMatrix
from .earlystop_handler import EarlyStopHandler
from .hausdorff_distance import HausdorffDistance
from .iteration_metric import IterationMetric
from .lr_schedule_handler import LrScheduleHandler
Expand Down
1 change: 0 additions & 1 deletion monai/handlers/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "Checkpoint")
BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.4", exact_version, "BaseSaveHandler")

if TYPE_CHECKING:
from ignite.engine import Engine
Expand Down
95 changes: 95 additions & 0 deletions monai/handlers/earlystop_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Optional

from monai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping")

if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")


class EarlyStopHandler:
"""
EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events.
It‘s based on the `EarlyStopping` handler in ignite.

Args:
patience: number of events to wait if no improvement and then stop the training.
score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
object that the handler attached, can be a trainer or validator, and return a score `float`.
an improvement is considered if the score is higher.
trainer: trainer engine to stop the run if no improvement, if None, must call `set_trainer()` before training.
min_delta: a minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event, default to False.
epoch_level: check early stopping for every epoch or every iteration of the attached engine,
`True` is epoch level, `False` is iteration level, defaut to epoch level.

Note:
If in distributed training and uses loss value of every iteration to detect earlystopping,
the values may be different in different ranks.
User may attach this handler to validator engine to detect validation metrics and stop the training,
in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine.

"""

def __init__(
self,
patience: int,
score_function: Callable,
trainer: Optional[Engine] = None,
min_delta: float = 0.0,
cumulative_delta: bool = False,
epoch_level: bool = True,
) -> None:
self.patience = patience
self.score_function = score_function
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.epoch_level = epoch_level
self._handler = None

if trainer is not None:
self.set_trainer(trainer=trainer)

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if self.epoch_level:
engine.add_event_handler(Events.EPOCH_COMPLETED, self)
else:
engine.add_event_handler(Events.ITERATION_COMPLETED, self)

def set_trainer(self, trainer: Engine):
"""
Set trainer to execute early stop if not setting properly in `__init__()`.
"""
self._handler = EarlyStopping(
patience=self.patience,
score_function=self.score_function,
trainer=trainer,
min_delta=self.min_delta,
cumulative_delta=self.cumulative_delta,
)

def __call__(self, engine: Engine) -> None:
if self._handler is None:
raise RuntimeError("please set trainer in __init__() or call set_trainer() before training.")
self._handler(engine)
2 changes: 1 addition & 1 deletion monai/handlers/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MetricLogger:
logger = MetricLogger(evaluator=evaluator)

# construct the trainer with the logger passed in as a handler so that it logs loss values
trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(evaluator, 1)])
trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(1, evaluator)])

# run training, logger.loss will be a list of (iteration, loss) values, logger.metrics a dict with key
# "val_mean_dice" storing a list of (iteration, metric) values
Expand Down
19 changes: 15 additions & 4 deletions monai/handlers/validation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from monai.engines.evaluator import Evaluator
from monai.utils import exact_version, optional_import
Expand All @@ -28,24 +28,33 @@ class ValidationHandler:

"""

def __init__(self, validator: Evaluator, interval: int, epoch_level: bool = True) -> None:
def __init__(self, interval: int, validator: Optional[Evaluator] = None, epoch_level: bool = True) -> None:
"""
Args:
validator: run the validator when trigger validation, suppose to be Evaluator.
interval: do validation every N epochs or every N iterations during training.
validator: run the validator when trigger validation, suppose to be Evaluator.
if None, should call `set_validator()` before training.
epoch_level: execute validation every N epochs or N iterations.
`True` is epoch level, `False` is iteration level.

Raises:
TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``.

"""
if not isinstance(validator, Evaluator):
if validator is not None and not isinstance(validator, Evaluator):
raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.")
self.validator = validator
self.interval = interval
self.epoch_level = epoch_level

def set_validator(self, validator: Evaluator):
"""
Set validator if not setting in the __init__().
"""
if not isinstance(validator, Evaluator):
raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.")
self.validator = validator

def attach(self, engine: Engine) -> None:
"""
Args:
Expand All @@ -61,4 +70,6 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if self.validator is None:
raise RuntimeError("please set validator in __init__() or call `set_validator()` before training.")
self.validator.run(engine.state.epoch)
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def run_testsuit():
"test_save_imaged",
"test_ensure_channel_first",
"test_ensure_channel_firstd",
"test_handler_early_stop",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
66 changes: 66 additions & 0 deletions tests/test_handler_early_stop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from ignite.engine import Engine, Events

from monai.handlers import EarlyStopHandler


class TestHandlerEarlyStop(unittest.TestCase):
def test_early_stop_train_loss(self):
def _train_func(engine, batch):
return {"loss": 1.5}

trainer = Engine(_train_func)
EarlyStopHandler(
patience=5,
score_function=lambda x: x.state.output["loss"],
trainer=trainer,
epoch_level=False,
).attach(trainer)

trainer.run(range(4), max_epochs=2)
self.assertEqual(trainer.state.iteration, 6)
self.assertEqual(trainer.state.epoch, 2)

def test_early_stop_val_metric(self):
def _train_func(engine, batch):
pass

trainer = Engine(_train_func)
validator = Engine(_train_func)
validator.state.metrics["val_acc"] = 0.90

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
validator.state.metrics["val_acc"] += 0.01
validator.run(range(3))

handler = EarlyStopHandler(
patience=3,
score_function=lambda x: x.state.metrics["val_acc"],
trainer=None,
min_delta=0.1,
cumulative_delta=True,
epoch_level=True,
)
handler.attach(validator)
handler.set_trainer(trainer=trainer)

trainer.run(range(3), max_epochs=5)
self.assertEqual(trainer.state.iteration, 12)
self.assertEqual(trainer.state.epoch, 4)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion tests/test_handler_prob_map_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def inference(enging, batch):
evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen])

# set up validation handler
validation = ValidationHandler(evaluator, interval=1)
validation = ValidationHandler(interval=1, validator=None)
validation.attach(engine)
validation.set_validator(validator=evaluator)

engine.run(data_loader)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_handler_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _train_func(engine, batch):
# set up testing handler
val_data_loader = torch.utils.data.DataLoader(Dataset(data))
evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader)
saver = ValidationHandler(evaluator, interval=2)
saver = ValidationHandler(interval=2, validator=evaluator)
saver.attach(engine)

engine.run(data, max_epochs=5)
Expand Down