From efce731f210964cad248d812c5afc927300efdeb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 11:51:22 +0800 Subject: [PATCH 1/8] [DLMED] add EarlyStop handler Signed-off-by: Nic Ma --- docs/source/handlers.rst | 5 ++ monai/handlers/__init__.py | 1 + monai/handlers/earlystop_handler.py | 76 +++++++++++++++++++++++++++++ tests/test_handler_early_stop.py | 65 ++++++++++++++++++++++++ 4 files changed, 147 insertions(+) create mode 100644 monai/handlers/earlystop_handler.py create mode 100644 tests/test_handler_early_stop.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index a629b28b27..080e7e138c 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -110,3 +110,8 @@ SmartCache handler ------------------ .. autoclass:: SmartCacheHandler :members: + +EarlyStop handler +----------------- +.. autoclass:: EarlyStopHandler + :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 5669e8a9ee..a1f86310ae 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .earlystop_handler import EarlyStopHandler from .hausdorff_distance import HausdorffDistance from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py new file mode 100644 index 0000000000..1a646dc688 --- /dev/null +++ b/monai/handlers/earlystop_handler.py @@ -0,0 +1,76 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Optional, Callable + +from monai.utils import exact_version, optional_import + +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + + +class EarlyStopHandler(EarlyStopping): + """ + EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events. + It inherits the `EarlyStopping` handler in ignite. + + Args: + patience: number of events to wait if no improvement and then stop the training. + score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine` + object that the handler attached, can be a trainer or validator, and return a score `float`. + an improvement is considered if the score is higher. + trainer: trainer engine to stop the run if no improvement. + min_delta: a minimum increase in the score to qualify as an improvement, + i.e. an increase of less than or equal to `min_delta`, will count as no improvement. + cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise, + it defines an increase after the last event, default to False. + epoch_level: check early stopping for every epoch or every iteration of the attached engine, + `True` is epoch level, `False` is iteration level, defaut to epoch level. + + Note: + If in distributed training and uses loss value of every iteration to detect earlystopping, + the values may be different in different ranks. + User may attach this handler to validator engine to detect validation metrics and stop the training, + in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine. + + """ + + def __init__( + self, + patience: int, + score_function: Callable, + trainer: Engine, + min_delta: float = 0.0, + cumulative_delta: bool = False, + epoch_level: bool = True, + ) -> None: + super().__init__( + patience=patience, + score_function=score_function, + trainer=trainer, + min_delta=min_delta, + cumulative_delta=cumulative_delta, + ) + self.epoch_level = epoch_level + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + if self.epoch_level: + engine.add_event_handler(Events.EPOCH_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py new file mode 100644 index 0000000000..fb7bb2fec7 --- /dev/null +++ b/tests/test_handler_early_stop.py @@ -0,0 +1,65 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from ignite.engine import Engine, Events +from monai.handlers import EarlyStopHandler + + +class TestHandlerEarlyStop(unittest.TestCase): + def test_early_stop_train_loss(self): + + def _train_func(engine, batch): + return {"loss": 1.5} + + trainer = Engine(_train_func) + EarlyStopHandler( + patience=5, + score_function=lambda x: x.state.output["loss"], + trainer=trainer, + epoch_level=False, + ).attach(trainer) + + trainer.run(range(4), max_epochs=2) + self.assertEqual(trainer.state.iteration, 6) + self.assertEqual(trainer.state.epoch, 2) + + def test_early_stop_val_metric(self): + + def _train_func(engine, batch): + pass + + trainer = Engine(_train_func) + validator = Engine(_train_func) + validator.state.metrics["val_acc"] = 0.90 + + @trainer.on(Events.EPOCH_COMPLETED) + def run_validation(engine): + validator.state.metrics["val_acc"] += 0.01 + validator.run(range(3)) + + EarlyStopHandler( + patience=3, + score_function=lambda x: x.state.metrics["val_acc"], + trainer=trainer, + min_delta=0.1, + cumulative_delta=True, + epoch_level=True, + ).attach(validator) + + trainer.run(range(3), max_epochs=5) + self.assertEqual(trainer.state.iteration, 12) + self.assertEqual(trainer.state.epoch, 4) + + +if __name__ == "__main__": + unittest.main() From feb201b54ed5c321f7356aecc0356967594d761c Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 2 Apr 2021 03:57:15 +0000 Subject: [PATCH 2/8] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/earlystop_handler.py | 2 +- tests/test_handler_early_stop.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 1a646dc688..8f90573a47 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional, Callable +from typing import TYPE_CHECKING, Callable, Optional from monai.utils import exact_version, optional_import diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index fb7bb2fec7..28bf5ea852 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -12,12 +12,12 @@ import unittest from ignite.engine import Engine, Events + from monai.handlers import EarlyStopHandler class TestHandlerEarlyStop(unittest.TestCase): def test_early_stop_train_loss(self): - def _train_func(engine, batch): return {"loss": 1.5} @@ -34,7 +34,6 @@ def _train_func(engine, batch): self.assertEqual(trainer.state.epoch, 2) def test_early_stop_val_metric(self): - def _train_func(engine, batch): pass From 468ee177e83824f5e53b607e2b0a734c2db17c62 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 12:36:51 +0800 Subject: [PATCH 3/8] [DLMED] enhance validation handler Signed-off-by: Nic Ma --- monai/handlers/metric_logger.py | 2 +- monai/handlers/validation_handler.py | 17 +++++++++++++---- tests/min_tests.py | 1 + tests/test_handler_prob_map_producer.py | 3 ++- tests/test_handler_validation.py | 2 +- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 778ec13900..f9a3913c56 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -48,7 +48,7 @@ class MetricLogger: logger = MetricLogger(evaluator=evaluator) # construct the trainer with the logger passed in as a handler so that it logs loss values - trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(evaluator, 1)]) + trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(1, evaluator)]) # run training, logger.loss will be a list of (iteration, loss) values, logger.metrics a dict with key # "val_mean_dice" storing a list of (iteration, metric) values diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 4458a17380..7b4721fa09 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from monai.engines.evaluator import Evaluator from monai.utils import exact_version, optional_import @@ -28,11 +28,12 @@ class ValidationHandler: """ - def __init__(self, validator: Evaluator, interval: int, epoch_level: bool = True) -> None: + def __init__(self, interval: int, validator: Optional[Evaluator] = None, epoch_level: bool = True) -> None: """ Args: - validator: run the validator when trigger validation, suppose to be Evaluator. interval: do validation every N epochs or every N iterations during training. + validator: run the validator when trigger validation, suppose to be Evaluator. + if None, should call `set_validator()` before training. epoch_level: execute validation every N epochs or N iterations. `True` is epoch level, `False` is iteration level. @@ -40,12 +41,20 @@ def __init__(self, validator: Evaluator, interval: int, epoch_level: bool = True TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``. """ - if not isinstance(validator, Evaluator): + if validator is not None and not isinstance(validator, Evaluator): raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.") self.validator = validator self.interval = interval self.epoch_level = epoch_level + def set_validator(self, validator: Evaluator): + """ + Set validator if not setting in the __init__(). + """ + if not isinstance(validator, Evaluator): + raise TypeError(f"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.") + self.validator = validator + def attach(self, engine: Engine) -> None: """ Args: diff --git a/tests/min_tests.py b/tests/min_tests.py index 06231af0a1..abb5b73764 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -112,6 +112,7 @@ def run_testsuit(): "test_save_imaged", "test_ensure_channel_first", "test_ensure_channel_firstd", + "test_handler_early_stop", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index 8bf42131b4..4f719fccc0 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -82,8 +82,9 @@ def inference(enging, batch): evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen]) # set up validation handler - validation = ValidationHandler(evaluator, interval=1) + validation = ValidationHandler(interval=1, validator=None) validation.attach(engine) + validation.set_validator(validator=evaluator) engine.run(data_loader) diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 11a51c7213..06f400109d 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -37,7 +37,7 @@ def _train_func(engine, batch): # set up testing handler val_data_loader = torch.utils.data.DataLoader(Dataset(data)) evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader) - saver = ValidationHandler(evaluator, interval=2) + saver = ValidationHandler(interval=2, validator=evaluator) saver.attach(engine) engine.run(data, max_epochs=5) From ec80a53b0dc52a3ab0e3e71b8e37d32d2f03821d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 13:15:11 +0800 Subject: [PATCH 4/8] [DLMED] add set_trainer support Signed-off-by: Nic Ma --- monai/handlers/earlystop_handler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 8f90573a47..315d4889de 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from monai.utils import exact_version, optional_import @@ -74,3 +74,9 @@ def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.EPOCH_COMPLETED, self) else: engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def set_trainer(self, trainer: Engine): + """ + set trainer to execute early stop if not setting properly in `__init__()`. + """ + self.trainer = trainer From 3700c0541952182bc4fc9ae55bf88a9390e5e4ab Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 13:18:32 +0800 Subject: [PATCH 5/8] [DLMED] add more check Signed-off-by: Nic Ma --- monai/handlers/earlystop_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 315d4889de..1c548e5cfe 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -79,4 +79,6 @@ def set_trainer(self, trainer: Engine): """ set trainer to execute early stop if not setting properly in `__init__()`. """ + if not isinstance(trainer, Engine): + raise TypeError("trainer must be an instance of Engine.") self.trainer = trainer From 071e697c25bc710bd641af6576905c0373108acf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 14:13:17 +0800 Subject: [PATCH 6/8] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/earlystop_handler.py | 4 +++- monai/handlers/validation_handler.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 1c548e5cfe..2606c5dbed 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -14,11 +14,13 @@ from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") -EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") + if TYPE_CHECKING: from ignite.engine import Engine + from ignite.handlers import EarlyStopping else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") class EarlyStopHandler(EarlyStopping): diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 7b4721fa09..fbd4b7862e 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -70,4 +70,6 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ + if self.validator is None: + raise RuntimeError("please set validator in __init__() or call `set_validator()` before training.") self.validator.run(engine.state.epoch) From a61baead038f6303976940ef40d2452c86be23b1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 19:24:26 +0800 Subject: [PATCH 7/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/earlystop_handler.py | 42 ++++++++++++++++++----------- tests/test_handler_early_stop.py | 8 +++--- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 2606c5dbed..4b160d56de 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Optional from monai.utils import exact_version, optional_import @@ -23,17 +23,17 @@ EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") -class EarlyStopHandler(EarlyStopping): +class EarlyStopHandler: """ EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events. - It inherits the `EarlyStopping` handler in ignite. + It‘s based on the `EarlyStopping` handler in ignite. Args: patience: number of events to wait if no improvement and then stop the training. score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine` object that the handler attached, can be a trainer or validator, and return a score `float`. an improvement is considered if the score is higher. - trainer: trainer engine to stop the run if no improvement. + trainer: trainer engine to stop the run if no improvement, if None, must call `set_trainer()` before training. min_delta: a minimum increase in the score to qualify as an improvement, i.e. an increase of less than or equal to `min_delta`, will count as no improvement. cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise, @@ -53,19 +53,20 @@ def __init__( self, patience: int, score_function: Callable, - trainer: Engine, + trainer: Optional[Engine] = None, min_delta: float = 0.0, cumulative_delta: bool = False, epoch_level: bool = True, ) -> None: - super().__init__( - patience=patience, - score_function=score_function, - trainer=trainer, - min_delta=min_delta, - cumulative_delta=cumulative_delta, - ) + self.patience = patience + self.score_function = score_function + self.min_delta = min_delta + self.cumulative_delta = cumulative_delta self.epoch_level = epoch_level + self._handler = None + + if trainer is not None: + self.set_trainer(trainer=trainer) def attach(self, engine: Engine) -> None: """ @@ -79,8 +80,17 @@ def attach(self, engine: Engine) -> None: def set_trainer(self, trainer: Engine): """ - set trainer to execute early stop if not setting properly in `__init__()`. + Set trainer to execute early stop if not setting properly in `__init__()`. """ - if not isinstance(trainer, Engine): - raise TypeError("trainer must be an instance of Engine.") - self.trainer = trainer + self._handler = EarlyStopping( + patience=self.patience, + score_function=self.score_function, + trainer=trainer, + min_delta=self.min_delta, + cumulative_delta=self.cumulative_delta, + ) + + def __call__(self, engine: Engine) -> None: + if self._handler is None: + raise RuntimeError("please set trainer in __init__() or call set_trainer() before training.") + self._handler(engine) diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index 28bf5ea852..efe8e89825 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -46,14 +46,16 @@ def run_validation(engine): validator.state.metrics["val_acc"] += 0.01 validator.run(range(3)) - EarlyStopHandler( + handler = EarlyStopHandler( patience=3, score_function=lambda x: x.state.metrics["val_acc"], - trainer=trainer, + trainer=None, min_delta=0.1, cumulative_delta=True, epoch_level=True, - ).attach(validator) + ) + handler.attach(validator) + handler.set_trainer(trainer=trainer) trainer.run(range(3), max_epochs=5) self.assertEqual(trainer.state.iteration, 12) From f16e674aa4a6d73cd2245fc9bddb10fa18dbe72a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Apr 2021 22:14:42 +0800 Subject: [PATCH 8/8] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/checkpoint_saver.py | 1 - monai/handlers/earlystop_handler.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index fd80182ba2..68857e17ff 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -17,7 +17,6 @@ Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "Checkpoint") -BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.4", exact_version, "BaseSaveHandler") if TYPE_CHECKING: from ignite.engine import Engine diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 4b160d56de..99e072b81f 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -14,13 +14,12 @@ from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") if TYPE_CHECKING: from ignite.engine import Engine - from ignite.handlers import EarlyStopping else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") - EarlyStopping, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "EarlyStopping") class EarlyStopHandler: