Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle import ConfigParser
from monai.transforms import SaveImage
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from monai.utils.module import look_up_option, optional_import

logger = get_logger(module_name=__name__)
Expand Down Expand Up @@ -636,11 +636,11 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
progress.yaml, accuracies in CSV and a pickle file of the Algo object.
"""
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
algo.train(self.train_params)
acc = algo.get_score()

algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}
algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)

def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
Expand Down Expand Up @@ -675,8 +675,8 @@ def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
mode_dry_run = self.hpo_params.pop("nni_dry_run", False)
for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
name = algo_dict[AlgoKeys.ID]
algo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
Expand Down Expand Up @@ -772,13 +772,13 @@ def run(self):
)

if auto_train_choice:
skip_algos = [h[AlgoEnsembleKeys.ID] for h in history if h["is_trained"]]
skip_algos = [h[AlgoKeys.ID] for h in history if h[AlgoKeys.IS_TRAINED]]
if len(skip_algos) > 0:
logger.info(
f"Skipping already trained algos {skip_algos}."
"Set option train=True to always retrain all algos."
)
history = [h for h in history if not h["is_trained"]]
history = [h for h in history if not h[AlgoKeys.IS_TRAINED]]

if len(history) > 0:
if not self.hpo:
Expand All @@ -794,13 +794,13 @@ def run(self):
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h["is_trained"]]
history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h["is_trained"]]
history = [h for h in history if h[AlgoKeys.IS_TRAINED]]

if len(history) == 0:
raise ValueError(
Expand All @@ -816,7 +816,7 @@ def run(self):
if len(preds) > 0:
logger.info("Auto3Dseg picked the following networks to ensemble:")
for algo in ensembler.get_algo_ensemble():
logger.info(algo[AlgoEnsembleKeys.ID])
logger.info(algo[AlgoKeys.ID])

for pred in preds:
self.save_image(pred)
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle.config_parser import ConfigParser
from monai.utils import ensure_tuple
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")
Expand Down Expand Up @@ -539,5 +539,5 @@ def generate(

algo_to_pickle(gen_algo, template_path=algo.template_path)
self.history.append(
{AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: gen_algo}
{AlgoKeys.ID: name, AlgoKeys.ALGO: gen_algo}
) # track the previous, may create a persistent history
22 changes: 11 additions & 11 deletions monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from monai.auto3dseg.utils import datafold_read
from monai.bundle import ConfigParser
from monai.transforms import MeanEnsemble, VoteEnsemble
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from monai.utils.misc import prob2class
from monai.utils.module import look_up_option

Expand Down Expand Up @@ -59,7 +59,7 @@ def get_algo(self, identifier):
identifier: the name of the bundleAlgo
"""
for algo in self.algos:
if identifier == algo[AlgoEnsembleKeys.ID]:
if identifier == algo[AlgoKeys.ID]:
return algo

def get_algo_ensemble(self):
Expand Down Expand Up @@ -160,7 +160,7 @@ def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tenso
print(i)
preds = []
for algo in self.algo_ensemble:
infer_instance = algo[AlgoEnsembleKeys.ALGO]
infer_instance = algo[AlgoKeys.ALGO]
pred = infer_instance.predict(predict_files=[file], predict_params=param)
preds.append(pred[0])
outputs.append(self.ensemble_pred(preds, sigmoid=sigmoid))
Expand All @@ -187,7 +187,7 @@ def sort_score(self):
"""
Sort the best_metrics
"""
scores = concat_val_to_np(self.algos, [AlgoEnsembleKeys.SCORE])
scores = concat_val_to_np(self.algos, [AlgoKeys.SCORE])
return np.argsort(scores).tolist()

def collect_algos(self, n_best: int = -1) -> None:
Expand Down Expand Up @@ -238,14 +238,14 @@ def collect_algos(self) -> None:
best_model: BundleAlgo | None = None
for algo in self.algos:
# algorithm folder: {net}_{fold_index}_{other}
identifier = algo[AlgoEnsembleKeys.ID].split("_")[1]
identifier = algo[AlgoKeys.ID].split("_")[1]
try:
algo_id = int(identifier)
except ValueError as err:
raise ValueError(f"model identifier {identifier} is not number.") from err
if algo_id == f_idx and algo[AlgoEnsembleKeys.SCORE] > best_score:
if algo_id == f_idx and algo[AlgoKeys.SCORE] > best_score:
best_model = algo
best_score = algo[AlgoEnsembleKeys.SCORE]
best_score = algo[AlgoKeys.SCORE]
self.algo_ensemble.append(best_model)


Expand All @@ -268,7 +268,7 @@ class AlgoEnsembleBuilder:
"""

def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str | None = None):
self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = []
self.infer_algos: list[dict[AlgoKeys, Any]] = []
self.ensemble: AlgoEnsemble
self.data_src_cfg = ConfigParser(globals=False)

Expand All @@ -278,8 +278,8 @@ def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str
for algo_dict in history:
# load inference_config_paths

name = algo_dict[AlgoEnsembleKeys.ID]
gen_algo = algo_dict[AlgoEnsembleKeys.ALGO]
name = algo_dict[AlgoKeys.ID]
gen_algo = algo_dict[AlgoKeys.ALGO]

best_metric = gen_algo.get_score()
algo_path = gen_algo.output_path
Expand All @@ -306,7 +306,7 @@ def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: float
if best_metric is None:
raise ValueError("Feature to re-validate is to be implemented")

algo = {AlgoEnsembleKeys.ID: identifier, AlgoEnsembleKeys.ALGO: gen_algo, AlgoEnsembleKeys.SCORE: best_metric}
algo = {AlgoKeys.ID: identifier, AlgoKeys.ALGO: gen_algo, AlgoKeys.SCORE: best_metric}
self.infer_algos.append(algo)

def set_ensemble_method(self, ensemble: AlgoEnsemble, *args: Any, **kwargs: Any) -> None:
Expand Down
10 changes: 5 additions & 5 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys

nni, has_nni = optional_import("nni")
optuna, has_optuna = optional_import("optuna")
Expand Down Expand Up @@ -99,8 +99,8 @@ class NNIGen(HPOGen):
# Bundle Algorithms are already generated by BundleGen in work_dir
import_bundle_algo_history(work_dir, only_trained=False)
algo_dict = self.history[0] # pick the first algorithm
algo_name = algo_dict[AlgoEnsembleKeys.ID]
onealgo = algo_dict[AlgoEnsembleKeys.ALGO]
algo_name = algo_dict[AlgoKeys.ID]
onealgo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=onealgo)
nni_gen.print_bundle_algo_instruction()

Expand Down Expand Up @@ -238,7 +238,7 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}

if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
Expand Down Expand Up @@ -411,7 +411,7 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}
if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
Expand Down
13 changes: 4 additions & 9 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys


def import_bundle_algo_history(
Expand Down Expand Up @@ -49,17 +49,12 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

best_metric = algo_meta_data.get(AlgoEnsembleKeys.SCORE, None)
best_metric = algo_meta_data.get(AlgoKeys.SCORE, None)
is_trained = best_metric is not None

if (only_trained and is_trained) or not only_trained:
history.append(
{
AlgoEnsembleKeys.ID: name,
AlgoEnsembleKeys.ALGO: algo,
AlgoEnsembleKeys.SCORE: best_metric,
"is_trained": is_trained,
}
{AlgoKeys.ID: name, AlgoKeys.ALGO: algo, AlgoKeys.SCORE: best_metric, AlgoKeys.IS_TRAINED: is_trained}
)

return history
Expand All @@ -73,5 +68,5 @@ def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
"""
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
algo_to_pickle(algo, template_path=algo.template_path)
19 changes: 19 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import random
from enum import Enum

from monai.utils import deprecated

__all__ = [
"StrEnum",
"NumpyPadMode",
Expand Down Expand Up @@ -56,6 +58,7 @@
"LazyAttr",
"BundleProperty",
"BundlePropertyConfig",
"AlgoKeys",
]


Expand Down Expand Up @@ -592,6 +595,7 @@ class LabelStatsKeys(StrEnum):
LABEL_NCOMP = "ncomponents"


@deprecated(since="1.2", msg_suffix="please use `AlgoKeys` instead.")
class AlgoEnsembleKeys(StrEnum):
"""
Default keys for Mixed Ensemble
Expand Down Expand Up @@ -664,3 +668,18 @@ class BundlePropertyConfig(StrEnum):

ID = "id"
REF_ID = "refer_id"


class AlgoKeys(StrEnum):
"""
Default keys for templated Auto3DSeg Algo.
`ID` is the identifier of the algorithm. The string has the format of <name>_<idx>_<other>.
`ALGO` is the Auto3DSeg Algo instance.
`IS_TRAINED` is the status that shows if the Algo has been trained.
`SCORE` is the score the Algo has achieved after training.
"""

ID = "identifier"
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"
10 changes: 5 additions & 5 deletions tests/test_auto3dseg_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from tests.utils import (
SkipIfBeforePyTorchVersion,
get_testing_algo_template_path,
Expand Down Expand Up @@ -135,8 +135,8 @@ def test_ensemble(self) -> None:
history = bundle_generator.get_history()

for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
name = algo_dict[AlgoKeys.ID]
algo = algo_dict[AlgoKeys.ALGO]
_train_param = train_param.copy()
if name.startswith("segresnet"):
_train_param["network#init_filters"] = 8
Expand All @@ -148,7 +148,7 @@ def test_ensemble(self) -> None:
builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=1))
ensemble = builder.get_ensemble()
name = ensemble.get_algo_ensemble()[0][AlgoEnsembleKeys.ID]
name = ensemble.get_algo_ensemble()[0][AlgoKeys.ID]
if name.startswith("segresnet"):
pred_param["network#init_filters"] = 8
elif name.startswith("swinunetr"):
Expand All @@ -159,7 +159,7 @@ def test_ensemble(self) -> None:
builder.set_ensemble_method(AlgoEnsembleBestByFold(1))
ensemble = builder.get_ensemble()
for algo in ensemble.get_algo_ensemble():
print(algo[AlgoEnsembleKeys.ID])
print(algo[AlgoKeys.ID])

def tearDown(self) -> None:
set_determinism(None)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_auto3dseg_hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from tests.utils import (
SkipIfBeforePyTorchVersion,
get_testing_algo_template_path,
Expand Down Expand Up @@ -140,7 +140,7 @@ def setUp(self) -> None:
@skip_if_no_cuda
def test_run_algo(self) -> None:
algo_dict = self.history[0]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=override_param)
obj_filename = nni_gen.get_obj_filename()
# this function will be used in HPO via Python Fire
Expand All @@ -150,7 +150,7 @@ def test_run_algo(self) -> None:
@skip_if_no_optuna
def test_run_optuna(self) -> None:
algo_dict = self.history[0]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]

class OptunaGenLearningRate(OptunaGen):
def get_hyperparameters(self):
Expand All @@ -172,7 +172,7 @@ def get_hyperparameters(self):
@skip_if_no_cuda
def test_get_history(self) -> None:
algo_dict = self.history[0]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=override_param)
obj_filename = nni_gen.get_obj_filename()

Expand Down
6 changes: 3 additions & 3 deletions tests/test_integration_gpu_customization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from tests.utils import (
SkipIfBeforePyTorchVersion,
get_testing_algo_template_path,
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_ensemble_gpu_customization(self) -> None:
history = bundle_generator.get_history()

for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
algo.train(train_param)

builder = AlgoEnsembleBuilder(history, data_src_cfg)
Expand All @@ -151,7 +151,7 @@ def test_ensemble_gpu_customization(self) -> None:
builder.set_ensemble_method(AlgoEnsembleBestByFold(1))
ensemble = builder.get_ensemble()
for algo in ensemble.get_algo_ensemble():
print(algo[AlgoEnsembleKeys.ID])
print(algo[AlgoKeys.ID])

def tearDown(self) -> None:
self.test_dir.cleanup()
Expand Down