diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 57d3b6c014..fcd76e02a0 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -281,7 +281,7 @@ def __init__( # determine if we need to analyze, algo_gen or train from cache, unless manually provided self.analyze = not self.cache["analyze"] if analyze is None else analyze self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen - self.train = not self.cache["train"] if train is None else train + self.train = train self.ensemble = ensemble # last step, no need to check self.set_training_params() @@ -758,7 +758,8 @@ def run(self): logger.info("Skipping algorithm generation...") # step 3: algo training - if self.train: + auto_train_choice = self.train is None + if self.train or (auto_train_choice and not self.cache["train"]): history = import_bundle_algo_history(self.work_dir, only_trained=False) if len(history) == 0: @@ -767,10 +768,15 @@ def run(self): "Possibly the required algorithms generation step was not completed." ) - if not self.hpo: - self._train_algo_in_sequence(history) - else: - self._train_algo_in_nni(history) + if auto_train_choice: + history = [h for h in history if not h["is_trained"]] # skip trained + + if len(history) > 0: + if not self.hpo: + self._train_algo_in_sequence(history) + else: + self._train_algo_in_nni(history) + self.export_cache(train=True) else: logger.info("Skipping algorithm training...") @@ -798,4 +804,4 @@ def run(self): self.save_image(pred) logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.") - logger.info("Auto3Dseg pipeline is complete successfully.") + logger.info("Auto3Dseg pipeline is completed successfully.") diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py index 67cde64a2c..feadc08808 100644 --- a/monai/apps/auto3dseg/utils.py +++ b/monai/apps/auto3dseg/utils.py @@ -47,11 +47,14 @@ def import_bundle_algo_history( if isinstance(algo, BundleAlgo): # algo's template path needs override algo.template_path = algo_meta_data["template_path"] + best_metrics = "best_metrics" + is_trained = best_metrics in algo_meta_data + if only_trained: - if "best_metrics" in algo_meta_data: - history.append({name: algo}) + if is_trained: + history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data[best_metrics]}) else: - history.append({name: algo}) + history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data.get(best_metrics, None)}) return history