diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index a629b28b27..080e7e138c 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -110,3 +110,8 @@ SmartCache handler ------------------ .. autoclass:: SmartCacheHandler :members: + +EarlyStop handler +----------------- +.. autoclass:: EarlyStopHandler + :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 5669e8a9ee..a1f86310ae 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .earlystop_handler import EarlyStopHandler from .hausdorff_distance import HausdorffDistance from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index fd80182ba2..68857e17ff 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -17,7 +17,6 @@ Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "Checkpoint") -BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.4", exact_version, "BaseSaveHandler") if TYPE_CHECKING: from ignite.engine import Engine diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py new file mode 100644 index 0000000000..99e072b81f --- /dev/null +++ b/monai/handlers/earlystop_handler.py @@ -0,0 +1,95 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, Optional + +from monai.utils import exact_version, optional_import + +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") + +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + + +class EarlyStopHandler: + """ + EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events. + It‘s based on the `EarlyStopping` handler in ignite. + + Args: + patience: number of events to wait if no improvement and then stop the training. + score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine` + object that the handler attached, can be a trainer or validator, and return a score `float`. + an improvement is considered if the score is higher. + trainer: trainer engine to stop the run if no improvement, if None, must call `set_trainer()` before training. + min_delta: a minimum increase in the score to qualify as an improvement, + i.e. an increase of less than or equal to `min_delta`, will count as no improvement. + cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise, + it defines an increase after the last event, default to False. + epoch_level: check early stopping for every epoch or every iteration of the attached engine, + `True` is epoch level, `False` is iteration level, defaut to epoch level. + + Note: + If in distributed training and uses loss value of every iteration to detect earlystopping, + the values may be different in different ranks. + User may attach this handler to validator engine to detect validation metrics and stop the training, + in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine. + + """ + + def __init__( + self, + patience: int, + score_function: Callable, + trainer: Optional[Engine] = None, + min_delta: float = 0.0, + cumulative_delta: bool = False, + epoch_level: bool = True, + ) -> None: + self.patience = patience + self.score_function = score_function + self.min_delta = min_delta + self.cumulative_delta = cumulative_delta + self.epoch_level = epoch_level + self._handler = None + + if trainer is not None: + self.set_trainer(trainer=trainer) + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + if self.epoch_level: + engine.add_event_handler(Events.EPOCH_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def set_trainer(self, trainer: Engine): + """ + Set trainer to execute early stop if not setting properly in `__init__()`. + """ + self._handler = EarlyStopping( + patience=self.patience, + score_function=self.score_function, + trainer=trainer, + min_delta=self.min_delta, + cumulative_delta=self.cumulative_delta, + ) + + def __call__(self, engine: Engine) -> None: + if self._handler is None: + raise RuntimeError("please set trainer in __init__() or call set_trainer() before training.") + self._handler(engine) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 778ec13900..f9a3913c56 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -48,7 +48,7 @@ class MetricLogger: logger = MetricLogger(evaluator=evaluator) # construct the trainer with the logger passed in as a handler so that it logs loss values - trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(evaluator, 1)]) + trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(1, evaluator)]) # run training, logger.loss will be a list of (iteration, loss) values, logger.metrics a dict with key # "val_mean_dice" storing a list of (iteration, metric) values diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 4458a17380..fbd4b7862e 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from monai.engines.evaluator import Evaluator from monai.utils import exact_version, optional_import @@ -28,11 +28,12 @@ class ValidationHandler: """ - def __init__(self, validator: Evaluator, interval: int, epoch_level: bool = True) -> None: + def __init__(self, interval: int, validator: Optional[Evaluator] = None, epoch_level: bool = True) -> None: """ Args: - validator: run the validator when trigger validation, suppose to be Evaluator. interval: do validation every N epochs or every N iterations during training. + validator: run the validator when trigger validation, suppose to be Evaluator. + if None, should call `set_validator()` before training. epoch_level: execute validation every N epochs or N iterations. `True` is epoch level, `False` is iteration level. @@ -40,12 +41,20 @@ def __init__(self, validator: Evaluator, interval: int, epoch_level: bool = True TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``. """ - if not isinstance(validator, Evaluator): + if validator is not None and not isinstance(validator, Evaluator): raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.") self.validator = validator self.interval = interval self.epoch_level = epoch_level + def set_validator(self, validator: Evaluator): + """ + Set validator if not setting in the __init__(). + """ + if not isinstance(validator, Evaluator): + raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.") + self.validator = validator + def attach(self, engine: Engine) -> None: """ Args: @@ -61,4 +70,6 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ + if self.validator is None: + raise RuntimeError("please set validator in __init__() or call `set_validator()` before training.") self.validator.run(engine.state.epoch) diff --git a/tests/min_tests.py b/tests/min_tests.py index 06231af0a1..abb5b73764 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -112,6 +112,7 @@ def run_testsuit(): "test_save_imaged", "test_ensure_channel_first", "test_ensure_channel_firstd", + "test_handler_early_stop", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py new file mode 100644 index 0000000000..efe8e89825 --- /dev/null +++ b/tests/test_handler_early_stop.py @@ -0,0 +1,66 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from ignite.engine import Engine, Events + +from monai.handlers import EarlyStopHandler + + +class TestHandlerEarlyStop(unittest.TestCase): + def test_early_stop_train_loss(self): + def _train_func(engine, batch): + return {"loss": 1.5} + + trainer = Engine(_train_func) + EarlyStopHandler( + patience=5, + score_function=lambda x: x.state.output["loss"], + trainer=trainer, + epoch_level=False, + ).attach(trainer) + + trainer.run(range(4), max_epochs=2) + self.assertEqual(trainer.state.iteration, 6) + self.assertEqual(trainer.state.epoch, 2) + + def test_early_stop_val_metric(self): + def _train_func(engine, batch): + pass + + trainer = Engine(_train_func) + validator = Engine(_train_func) + validator.state.metrics["val_acc"] = 0.90 + + @trainer.on(Events.EPOCH_COMPLETED) + def run_validation(engine): + validator.state.metrics["val_acc"] += 0.01 + validator.run(range(3)) + + handler = EarlyStopHandler( + patience=3, + score_function=lambda x: x.state.metrics["val_acc"], + trainer=None, + min_delta=0.1, + cumulative_delta=True, + epoch_level=True, + ) + handler.attach(validator) + handler.set_trainer(trainer=trainer) + + trainer.run(range(3), max_epochs=5) + self.assertEqual(trainer.state.iteration, 12) + self.assertEqual(trainer.state.epoch, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index 8bf42131b4..4f719fccc0 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -82,8 +82,9 @@ def inference(enging, batch): evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen]) # set up validation handler - validation = ValidationHandler(evaluator, interval=1) + validation = ValidationHandler(interval=1, validator=None) validation.attach(engine) + validation.set_validator(validator=evaluator) engine.run(data_loader) diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 11a51c7213..06f400109d 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -37,7 +37,7 @@ def _train_func(engine, batch): # set up testing handler val_data_loader = torch.utils.data.DataLoader(Dataset(data)) evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader) - saver = ValidationHandler(evaluator, interval=2) + saver = ValidationHandler(interval=2, validator=evaluator) saver.attach(engine) engine.run(data, max_epochs=5)