diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index b6cb503a92..0651c6ff33 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -12,6 +12,7 @@ from __future__ import annotations import logging +import os import warnings from collections.abc import Mapping from typing import TYPE_CHECKING, Any @@ -118,6 +119,7 @@ def __init__( self._key_metric_checkpoint: Checkpoint | None = None self._interval_checkpoint: Checkpoint | None = None self._name = name + self._final_filename = final_filename class _DiskSaver(DiskSaver): """ @@ -148,7 +150,7 @@ def _final_func(engine: Engine) -> Any: self._final_checkpoint = Checkpoint( to_save=self.save_dict, - save_handler=_DiskSaver(dirname=self.save_dir, filename=final_filename), + save_handler=_DiskSaver(dirname=self.save_dir, filename=self._final_filename), filename_prefix=file_prefix, score_function=_final_func, score_name="final_iteration", @@ -271,7 +273,11 @@ def completed(self, engine: Engine) -> None: raise AssertionError if not hasattr(self.logger, "info"): raise AssertionError("Error, provided logger has not info attribute.") - self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}") + if self._final_filename is not None: + _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename) + else: + _final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment] + self.logger.info(f"Train completed, saved final checkpoint: {_final_checkpoint_path}") def exception_raised(self, engine: Engine, e: Exception) -> None: """Callback for train or validation/evaluation exception raised Event. @@ -291,7 +297,11 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: raise AssertionError if not hasattr(self.logger, "info"): raise AssertionError("Error, provided logger has not info attribute.") - self.logger.info(f"Exception raised, saved the last checkpoint: {self._final_checkpoint.last_checkpoint}") + if self._final_filename is not None: + _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename) + else: + _final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment] + self.logger.info(f"Exception raised, saved the last checkpoint: {_final_checkpoint_path}") raise e def metrics_completed(self, engine: Engine) -> None: