diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py
index 57d8728cd4..0cc05b2dc4 100644
--- a/monai/handlers/checkpoint_saver.py
+++ b/monai/handlers/checkpoint_saver.py
@@ -188,6 +188,20 @@ def attach(self, engine: Engine) -> None:
             else:
                 engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.save_interval), self.interval_completed)
 
+    def _delete_previous_final_ckpt(self) -> None:
+        """Remove the final checkpoint left by an earlier save, if there is one.
+
+        Takes the first entry out of ``self._final_checkpoint._saved``, asks the
+        checkpoint's ``save_handler`` to remove that file and logs its name.
+        Does nothing when no entry is tracked.
+        """
+        previous = self._final_checkpoint._saved
+        if not previous:
+            return
+        stale = previous.pop(0)
+        self._final_checkpoint.save_handler.remove(stale.filename)
+        self.logger.info(f"Deleted previous saved final checkpoint: {stale.filename}")
+
     def completed(self, engine: Engine) -> None:
         """Callback for train or validation/evaluation completed Event.
         Save final checkpoint if configure save_final is True.
@@ -196,6 +210,8 @@ def completed(self, engine: Engine) -> None:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
         assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified."
+        # delete previous saved final checkpoint if existing
+        self._delete_previous_final_ckpt()
         self._final_checkpoint(engine)
         assert self.logger is not None
         assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
@@ -211,6 +227,8 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
             e: the exception caught in Ignite during engine.run().
         """
         assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified."
+        # delete previous saved final checkpoint if existing
+        self._delete_previous_final_ckpt()
         self._final_checkpoint(engine)
         assert self.logger is not None
         assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py
index 8513b6625e..2df36d9720 100644
--- a/tests/test_handler_checkpoint_saver.py
+++ b/tests/test_handler_checkpoint_saver.py
@@ -117,6 +117,8 @@ def _train_func(engine, batch):
                 n_saved,
             )
             handler.attach(engine)
+            # run the engine twice so the attached handler also sees a run that follows an earlier one
+            engine.run(data, max_epochs=2)
             engine.run(data, max_epochs=5)
             for filename in filenames:
                 self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))