Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = not self.cache["train"] if train is None else train
self.train = train
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
Expand Down Expand Up @@ -758,7 +758,8 @@ def run(self):
logger.info("Skipping algorithm generation...")

# step 3: algo training
if self.train:
auto_train_choice = self.train is None
if self.train or (auto_train_choice and not self.cache["train"]):
history = import_bundle_algo_history(self.work_dir, only_trained=False)

if len(history) == 0:
Expand All @@ -767,10 +768,15 @@ def run(self):
"Possibly the required algorithms generation step was not completed."
)

if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)
if auto_train_choice:
history = [h for h in history if not h["is_trained"]] # skip trained

if len(history) > 0:
if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)

self.export_cache(train=True)
else:
logger.info("Skipping algorithm training...")
Expand Down Expand Up @@ -798,4 +804,4 @@ def run(self):
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

logger.info("Auto3Dseg pipeline is complete successfully.")
logger.info("Auto3Dseg pipeline is completed successfully.")
9 changes: 6 additions & 3 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

best_metrics = "best_metrics"
is_trained = best_metrics in algo_meta_data

if only_trained:
if "best_metrics" in algo_meta_data:
history.append({name: algo})
if is_trained:
history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data[best_metrics]})
else:
history.append({name: algo})
history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data.get(best_metrics, None)})

return history

Expand Down