From b712d730d859b519917bb24d567e91c88e13ec0f Mon Sep 17 00:00:00 2001 From: myron Date: Tue, 4 Apr 2023 21:06:44 -0700 Subject: [PATCH 1/3] resuming training, skipping trained algos Signed-off-by: myron --- monai/apps/auto3dseg/auto_runner.py | 21 ++++++++++++++------- monai/apps/auto3dseg/utils.py | 9 ++++++--- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 57d3b6c014..2bae5bf893 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -281,7 +281,7 @@ def __init__( # determine if we need to analyze, algo_gen or train from cache, unless manually provided self.analyze = not self.cache["analyze"] if analyze is None else analyze self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen - self.train = not self.cache["train"] if train is None else train + self.train = train self.ensemble = ensemble # last step, no need to check self.set_training_params() @@ -758,7 +758,9 @@ def run(self): logger.info("Skipping algorithm generation...") # step 3: algo training - if self.train: + auto_train_choice = self.train is None + if self.train == True or (auto_train_choice and not self.cache["train"]): + history = import_bundle_algo_history(self.work_dir, only_trained=False) if len(history) == 0: @@ -767,10 +769,15 @@ def run(self): "Possibly the required algorithms generation step was not completed." ) - if not self.hpo: - self._train_algo_in_sequence(history) - else: - self._train_algo_in_nni(history) + if auto_train_choice: + history = [h for h in history if not h["is_trained"]] # skip trained + + if len(history) > 0: + if not self.hpo: + self._train_algo_in_sequence(history) + else: + self._train_algo_in_nni(history) + self.export_cache(train=True) else: logger.info("Skipping algorithm training...") @@ -798,4 +805,4 @@ def run(self): self.save_image(pred) logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.") - logger.info("Auto3Dseg pipeline is complete successfully.") + logger.info("Auto3Dseg pipeline is completed successfully.") diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py index 67cde64a2c..feadc08808 100644 --- a/monai/apps/auto3dseg/utils.py +++ b/monai/apps/auto3dseg/utils.py @@ -47,11 +47,14 @@ def import_bundle_algo_history( if isinstance(algo, BundleAlgo): # algo's template path needs override algo.template_path = algo_meta_data["template_path"] + best_metrics = "best_metrics" + is_trained = best_metrics in algo_meta_data + if only_trained: - if "best_metrics" in algo_meta_data: - history.append({name: algo}) + if is_trained: + history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data[best_metrics]}) else: - history.append({name: algo}) + history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data.get(best_metrics, None)}) return history From d67fe8c0c7d4f15dc29217e711f698b303d30302 Mon Sep 17 00:00:00 2001 From: myron Date: Wed, 5 Apr 2023 00:15:05 -0700 Subject: [PATCH 2/3] black Signed-off-by: myron --- monai/apps/auto3dseg/auto_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 2bae5bf893..b56ce2dc81 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -760,7 +760,6 @@ def run(self): # step 3: algo training auto_train_choice = self.train is None if self.train == True or (auto_train_choice and not self.cache["train"]): - history = import_bundle_algo_history(self.work_dir, only_trained=False) if len(history) == 0: From 99baeae4d5126f85dcec2522a6febb264ed48155 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Wed, 5 Apr 2023 09:01:04 +0100 Subject: [PATCH 3/3] Update monai/apps/auto3dseg/auto_runner.py Signed-off-by: Wenqi Li <831580+wyli@users.noreply.github.com> --- monai/apps/auto3dseg/auto_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index b56ce2dc81..fcd76e02a0 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -759,7 +759,7 @@ def run(self): # step 3: algo training auto_train_choice = self.train is None - if self.train == True or (auto_train_choice and not self.cache["train"]): + if self.train or (auto_train_choice and not self.cache["train"]): history = import_bundle_algo_history(self.work_dir, only_trained=False) if len(history) == 0: