From 23cde800599c1453c49cd364cc469895f0e18474 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 22 Feb 2021 12:03:54 +0800
Subject: [PATCH 1/3] [DLMED] add support to resume metrics in CheckpointSaver

Signed-off-by: Nic Ma
---
 monai/handlers/checkpoint_saver.py     | 32 +++++++++++++++
 tests/test_handler_checkpoint_saver.py | 54 +++++++++++++++++++++++---
 2 files changed, 81 insertions(+), 5 deletions(-)

diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py
index 8052e21cb6..ba3a76171f 100644
--- a/monai/handlers/checkpoint_saver.py
+++ b/monai/handlers/checkpoint_saver.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 import logging
+import warnings
 from typing import TYPE_CHECKING, Dict, Optional
 
 from monai.utils import exact_version, optional_import
@@ -55,6 +56,10 @@ class CheckpointSaver:
             metric in descending order.
         key_metric_filename: set a fixed filename to set the best metric model, if not None,
             `key_metric_n_saved` should be 1 and only keep the best metric model.
+        key_metric_save_state: whether to save the tracking list of the key metric in the checkpoint file.
+            if `True`, an object is saved in the checkpoint file with the key `checkpointer`, consistent
+            with ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py#L99.
+            typically used to resume training and compare the current metric with the previous N values.
         epoch_level: save checkpoint during training for every N epochs or every N iterations.
             `True` is epoch level, `False` is iteration level.
         save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
@@ -84,6 +89,7 @@ def __init__(
         key_metric_name: Optional[str] = None,
         key_metric_n_saved: int = 1,
         key_metric_filename: Optional[str] = None,
+        key_metric_save_state: bool = False,
         epoch_level: bool = True,
         save_interval: int = 0,
         n_saved: Optional[int] = None,
@@ -156,6 +162,7 @@ def _score_func(engine: Engine):
                 score_function=_score_func,
                 score_name="key_metric",
                 n_saved=key_metric_n_saved,
+                include_self=key_metric_save_state,
             )
 
         if save_interval > 0:
@@ -172,6 +179,31 @@ def _interval_func(engine: Engine):
                 n_saved=n_saved,
             )
 
+    def load_state_dict(self, state_dict: Dict) -> None:
+        """
+        Utility to resume the internal state of the key metric tracking list if configured to save
+        checkpoints based on the key metric value.
+        Note that `key_metric_save_state=True` must be set when saving the previous checkpoint.
+
+        Example::
+            CheckpointSaver(
+                ...
+                save_key_metric=True,
+                key_metric_save_state=True,  # also save the state of this saver
+            ).attach(engine)
+            engine.run(...)
+
+            # resumed training with a new CheckpointSaver
+            saver = CheckpointSaver(save_key_metric=True, ...)
+
+            # load the previous key metric tracking list into saver
+            CheckpointLoader("/test/model.pt", {"checkpointer": saver}).attach(engine)
+
+        """
+        if self._key_metric_checkpoint is not None:
+            self._key_metric_checkpoint.load_state_dict(state_dict)
+        else:
+            warnings.warn("no key metric checkpoint saver to resume the key metric tracking list.")
+
     def attach(self, engine: Engine) -> None:
         """
         Args:
diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py
index 0ae8be1e73..0ad77fcab7 100644
--- a/tests/test_handler_checkpoint_saver.py
+++ b/tests/test_handler_checkpoint_saver.py
@@ -20,9 +20,9 @@
 from ignite.engine import Engine
 from parameterized import parameterized
 
-from monai.handlers import CheckpointSaver
+from monai.handlers import CheckpointSaver, CheckpointLoader
 
-TEST_CASE_1 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
+TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
 
 TEST_CASE_2 = [
     False,
@@ -31,6 +31,7 @@
     "val_loss",
     2,
     None,
+    False,
     True,
     0,
     None,
@@ -44,6 +45,7 @@
     None,
     1,
     None,
+    False,
     True,
     2,
     2,
@@ -58,16 +60,17 @@
     1,
     None,
     False,
+    False,
     10,
     2,
     ["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"],
 ]
 
-TEST_CASE_5 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
+TEST_CASE_5 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
 
-TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, True, 0, None, ["final_model.pt"]]
+TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, True, 0, None, ["final_model.pt"]]
 
-TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", True, 0, None, ["model.pt"]]
+TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, True, 0, None, ["model.pt"]]
 
 
 class TestHandlerCheckpointSaver(unittest.TestCase):
@@ -80,6 +83,7 @@ def test_file(
         key_metric_name,
         key_metric_n_saved,
         key_metric_filename,
+        key_metric_save_state,
         epoch_level,
         save_interval,
         n_saved,
@@ -112,6 +116,7 @@ def _train_func(engine, batch):
             key_metric_name,
             key_metric_n_saved,
             key_metric_filename,
+            key_metric_save_state,
             epoch_level,
             save_interval,
             n_saved,
@@ -141,6 +146,45 @@ def _train_func(engine, batch):
         engine.run(range(3), max_epochs=2)
         self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))
 
+    def test_load_state_dict(self):
+        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+        net = torch.nn.PReLU()
+
+        # set up engine
+        def _train_func(engine, batch):
+            engine.state.metrics["val_loss"] = engine.state.iteration
+
+        engine = Engine(_train_func)
+
+        # set up testing handler
+        with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(_train_func)
+            CheckpointSaver(
+                save_dir=tempdir,
+                save_dict={"net": net},
+                save_key_metric=True,
+                key_metric_name="val_loss",
+                key_metric_n_saved=2,
+                key_metric_save_state=True,
+            ).attach(engine)
+            engine.run(range(3), max_epochs=2)
+
+            saver = CheckpointSaver(
+                save_dir=tempdir,
+                save_dict={"net": net},
+                save_key_metric=True,
+                key_metric_name="val_loss",
+                key_metric_n_saved=2,
+            )
+            engine = Engine(_train_func)
+            CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
+            engine.run(range(1), max_epochs=1)
+
+            resumed = saver._key_metric_checkpoint._saved
+            for i in range(2):
+                self.assertEqual(resumed[i].priority, 3 * (i + 1))
+                self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt")
+
 
 if __name__ == "__main__":
     unittest.main()

From 558c9e52859d2ad54dbce6a5b5a7ca09a138f690 Mon Sep 17 00:00:00 2001
From: monai-bot
Date: Mon, 22 Feb 2021 04:10:32 +0000
Subject: [PATCH 2/3] [MONAI] python code formatting

Signed-off-by: monai-bot
---
 tests/test_handler_checkpoint_saver.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py
index 0ad77fcab7..5c2b750a57 100644
--- a/tests/test_handler_checkpoint_saver.py
+++ b/tests/test_handler_checkpoint_saver.py
@@ -20,7 +20,7 @@
 from ignite.engine import Engine
 from parameterized import parameterized
 
-from monai.handlers import CheckpointSaver, CheckpointLoader
+from monai.handlers import CheckpointLoader, CheckpointSaver
 
 TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
 

From 7c82a63476c6634abf87934cc8a41132ab4fdfb2 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 22 Feb 2021 12:23:34 +0800
Subject: [PATCH 3/3] [DLMED] fix doc-string

Signed-off-by: Nic Ma
---
 monai/handlers/checkpoint_saver.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py
index ba3a76171f..1808e6b251 100644
--- a/monai/handlers/checkpoint_saver.py
+++ b/monai/handlers/checkpoint_saver.py
@@ -186,6 +186,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
         Note that `key_metric_save_state=True` must be set when saving the previous checkpoint.
 
         Example::
+
             CheckpointSaver(
                 ...
                 save_key_metric=True,
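Usage sketch (not part of the patches; assumes a MONAI build with the patches above applied): a
minimal end-to-end script mirroring the `test_load_state_dict` case. With
`key_metric_save_state=True` the saver stores its own tracking list in each key-metric checkpoint
under the key `checkpointer`; a fresh `CheckpointSaver` handed to `CheckpointLoader` under that
same key resumes the top-N bookkeeping. The key metric checkpoint is written at each epoch end,
so with `val_loss` equal to the iteration count (3 iterations per epoch, 2 epochs) the two best
files are `net_key_metric=3.pt` and `net_key_metric=6.pt`, and loading `net_key_metric=6.pt`
restores the full top-2 tracking list.

    import os
    import tempfile

    import torch
    from ignite.engine import Engine

    from monai.handlers import CheckpointLoader, CheckpointSaver

    net = torch.nn.PReLU()

    def _train_func(engine, batch):
        # use the iteration count as a stand-in validation metric
        engine.state.metrics["val_loss"] = engine.state.iteration

    with tempfile.TemporaryDirectory() as tempdir:
        # first run: keep the 2 best checkpoints and persist the tracking list itself
        engine = Engine(_train_func)
        CheckpointSaver(
            save_dir=tempdir,
            save_dict={"net": net},
            save_key_metric=True,
            key_metric_name="val_loss",
            key_metric_n_saved=2,
            key_metric_save_state=True,  # store the saver state under "checkpointer"
        ).attach(engine)
        engine.run(range(3), max_epochs=2)

        # resumed run: a new saver recovers the previous top-2 list from the checkpoint
        saver = CheckpointSaver(
            save_dir=tempdir,
            save_dict={"net": net},
            save_key_metric=True,
            key_metric_name="val_loss",
            key_metric_n_saved=2,
        )
        engine = Engine(_train_func)
        CheckpointLoader(
            load_path=os.path.join(tempdir, "net_key_metric=6.pt"),
            load_dict={"checkpointer": saver},
        ).attach(engine)
        engine.run(range(1), max_epochs=1)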