From 88a4a41a6201d03558e6cf44c5af595fbd1aa2cf Mon Sep 17 00:00:00 2001 From: Mingxin Date: Fri, 12 Apr 2024 04:13:25 +0000 Subject: [PATCH 1/2] Add checks for num_fold and fail early if wrong Signed-off-by: Mingxin --- monai/apps/auto3dseg/auto_runner.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 52a0824227..d3e2b3a2f6 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -298,9 +298,11 @@ def __init__( pass # inspect and update folds - num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) + self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) if "num_fold" in self.data_src_cfg: num_fold = int(self.data_src_cfg["num_fold"]) # override from config + else: + num_fold = self.max_fold self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -399,6 +401,9 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int: if len(fold_list) > 0: num_fold = max(fold_list) + 1 logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") + # check if every fold is present + if len(set(fold_list)) != num_fold: + raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}") elif "validation" in datalist and len(datalist["validation"]) > 0: logger.info("No fold numbers provided, attempting to use a single fold based on the validation key") # update the datalist file @@ -492,6 +497,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner: if num_fold <= 0: raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}") + if num_fold > self.max_fold + 1: + # Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1 + raise ValueError( + f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}." + ) self.num_fold = num_fold return self From d6b8ffa6278b2f35d58543b8b4e371821d4793eb Mon Sep 17 00:00:00 2001 From: Mingxin Date: Fri, 12 Apr 2024 04:19:20 +0000 Subject: [PATCH 2/2] improve log Signed-off-by: Mingxin --- monai/apps/auto3dseg/auto_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index d3e2b3a2f6..05c961f999 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -301,8 +301,10 @@ def __init__( self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) if "num_fold" in self.data_src_cfg: num_fold = int(self.data_src_cfg["num_fold"]) # override from config + logger.info(f"Setting num_fold {num_fold} based on the input config.") else: num_fold = self.max_fold + logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -400,7 +402,7 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int: if len(fold_list) > 0: num_fold = max(fold_list) + 1 - logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") + logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.") # check if every fold is present if len(set(fold_list)) != num_fold: raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}")