28 changes: 17 additions & 11 deletions monai/handlers/mlflow_handler.py
@@ -23,8 +23,10 @@
 from monai.utils import ensure_tuple, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
-mlflow, _ = optional_import("mlflow")
-mlflow.entities, _ = optional_import("mlflow.entities")
+mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
+mlflow.entities, _ = optional_import(
+    "mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
+)
 
 if TYPE_CHECKING:
     from ignite.engine import Engine
@@ -76,21 +78,23 @@ class MLFlowHandler:
             The default behavior is to track loss from output[0] as output is a decollated list
             and we replicated loss value for every item of the decollated list.
             `engine.state` and `output_transform` inherit from the ignite concept:
-            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://pytorch-ignite.ai/concepts/03-state/, explanation and usage example are in the tutorial:
             https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
         global_epoch_transform: a callable that is used to customize global epoch number.
             For example, in evaluation, the evaluator engine might want to track synced epoch number
             with the trainer engine.
         state_attributes: expected attributes from `engine.state`, if provided, will extract them
             when epoch completed.
         tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`.
-        experiment_name: name for an experiment, defaults to `default_experiment`.
-        run_name: name for run in an experiment.
-        experiment_param: a dict recording parameters which will not change through whole experiment,
+        experiment_name: the experiment name of MLflow, default to `'monai_experiment'`. An experiment can be
+            used to record several runs.
+        run_name: the run name in an experiment. A run can be used to record information about a workflow,
+            like the loss, metrics and so on.
+        experiment_param: a dict recording parameters which will not change through the whole workflow,
             like torch version, cuda version and so on.
-        artifacts: paths to images that need to be recorded after a whole run.
-        optimizer_param_names: parameters' name in optimizer that need to be record during running,
-            defaults to "lr".
+        artifacts: paths to images that need to be recorded after running the workflow.
+        optimizer_param_names: parameter names in the optimizer that need to be recorded during running the
+            workflow, default to `'lr'`.
         close_on_complete: whether to close the mlflow run in `complete` phase in workflow, default to False.
 
     For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html.
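To make the documented parameters concrete, a hedged usage sketch follows; the tracking URI, run name, and parameter values are illustrative placeholders, and `trainer` is assumed to be an existing ignite `Engine`:

```python
from monai.handlers import MLFlowHandler

handler = MLFlowHandler(
    tracking_uri="file:///tmp/mlruns",             # hypothetical local store
    experiment_name="monai_experiment",            # the documented default
    run_name="baseline_run",                       # hypothetical run name
    experiment_param={"torch_version": "1.13.0"},  # fixed for the whole workflow
    optimizer_param_names="lr",                    # the documented default
    close_on_complete=True,                        # mark the run FINISHED on completion
)
handler.attach(trainer)  # `trainer`: an ignite Engine defined elsewhere
```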
@@ -132,6 +136,7 @@ def __init__(
         self.artifacts = ensure_tuple(artifacts)
         self.optimizer_param_names = ensure_tuple(optimizer_param_names)
         self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None)
+        self.run_finish_status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
        self.close_on_complete = close_on_complete
        self.experiment = None
        self.cur_run = None
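The cached `run_finish_status` is simply the canonical string form of the FINISHED status; a quick check of the conversion, as standard mlflow behavior:

```python
from mlflow.entities import RunStatus

# RunStatus values are ints; to_string yields the name stored in run metadata.
print(RunStatus.to_string(RunStatus.FINISHED))  # -> "FINISHED"
```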
@@ -191,6 +196,8 @@ def start(self, engine: Engine) -> None:
         run_name = f"run_{time.strftime('%Y%m%d_%H%M%S')}" if self.run_name is None else self.run_name
         runs = self.client.search_runs(self.experiment.experiment_id)
         runs = [r for r in runs if r.info.run_name == run_name or not self.run_name]
+        # runs marked as finish should not record info any more
+        runs = [r for r in runs if r.info.status != self.run_finish_status]
         if runs:
             self.cur_run = self.client.get_run(runs[-1].info.run_id)  # pick latest active run
         else:
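The two added lines ensure that a run already marked FINISHED is never reused for new logging. A standalone sketch of the same reuse-or-create logic, assuming `client` is an `mlflow.MlflowClient` and `experiment_id` is valid:

```python
runs = client.search_runs([experiment_id])  # documented form takes a list of ids
# drop terminated runs so a closed run is never resumed
active = [r for r in runs if r.info.status != "FINISHED"]
if active:
    run = client.get_run(active[-1].info.run_id)  # reuse latest active run
else:
    run = client.create_run(experiment_id)        # start a fresh run
```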
@@ -264,8 +271,7 @@ def close(self) -> None:
 
         """
         if self.cur_run:
-            status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
-            self.client.set_terminated(self.cur_run.info.run_id, status)
+            self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status)
             self.cur_run = None
 
     def epoch_completed(self, engine: Engine) -> None:
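With the status string cached in `__init__`, `close` now reduces to a single client call; an equivalent standalone call would look like this (assuming `client` and `run_id` exist):

```python
# Mark the run FINISHED; mlflow also records an end time for the run.
client.set_terminated(run_id, "FINISHED")
```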
9 changes: 8 additions & 1 deletion tests/test_auto3dseg_bundlegen.py
@@ -25,7 +25,13 @@
 from monai.bundle.config_parser import ConfigParser
 from monai.data import create_test_image_3d
 from monai.utils import set_determinism
-from tests.utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
+from tests.utils import (
+    SkipIfBeforePyTorchVersion,
+    get_testing_algo_template_path,
+    skip_if_downloading_fails,
+    skip_if_no_cuda,
+    skip_if_quick,
+)
 
 num_images_perfold = max(torch.cuda.device_count(), 4)
 num_images_per_batch = 2
@@ -97,6 +103,7 @@ def run_auto3dseg_before_bundlegen(test_path, work_dir):
 
 
 @skip_if_no_cuda
+@SkipIfBeforePyTorchVersion((1, 11, 1))
 @skip_if_quick
 class TestBundleGen(unittest.TestCase):
     def setUp(self) -> None:
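For readers unfamiliar with the test utility, the decorator gates the whole test class on the installed PyTorch version. A rough, hypothetical approximation of such a gate (not MONAI's actual implementation) might look like:

```python
import unittest

import torch
from packaging.version import Version

def skip_if_before_pytorch(min_version: tuple):
    """Hypothetical: skip the decorated test on PyTorch older than min_version."""
    required = Version(".".join(map(str, min_version)))
    installed = Version(torch.__version__.split("+")[0])  # drop local build tag
    return unittest.skipIf(installed < required, f"requires PyTorch >= {required}")
```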
33 changes: 33 additions & 0 deletions tests/test_handler_mlflow.py
@@ -68,6 +68,39 @@ def tearDown(self):
         if tmpdir and os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)
 
+    def test_multi_run(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            # set up the train function for engine
+            def _train_func(engine, batch):
+                return [batch + 1.0]
+
+            # create and run an engine several times to get several runs
+            create_engine_times = 3
+            for _ in range(create_engine_times):
+                engine = Engine(_train_func)
+
+                @engine.on(Events.EPOCH_COMPLETED)
+                def _update_metric(engine):
+                    current_metric = engine.state.metrics.get("acc", 0.1)
+                    engine.state.metrics["acc"] = current_metric + 0.1
+                    engine.state.test = current_metric
+
+                # set up testing handler
+                test_path = os.path.join(tempdir, "mlflow_test")
+                handler = MLFlowHandler(
+                    iteration_log=False,
+                    epoch_log=True,
+                    tracking_uri=path_to_uri(test_path),
+                    state_attributes=["test"],
+                    close_on_complete=True,
+                )
+                handler.attach(engine)
+                engine.run(range(3), max_epochs=2)
+                run_cnt = len(handler.client.search_runs(handler.experiment.experiment_id))
+                handler.close()
+            # the run count should equal to the times of creating engine
+            self.assertEqual(create_engine_times, run_cnt)
+
     def test_metrics_track(self):
         experiment_param = {"backbone": "efficientnet_b0"}
         with tempfile.TemporaryDirectory() as tempdir:
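One detail worth noting in the test above: `path_to_uri` (imported from `monai.utils` in this test module) converts the temporary directory into a `file://` URI so mlflow treats it as a local tracking store. Roughly equivalent, as an illustrative sketch rather than the helper's exact source:

```python
from pathlib import Path

def path_to_uri_sketch(path: str) -> str:
    # illustrative equivalent of the helper used in the test
    return Path(path).absolute().as_uri()

print(path_to_uri_sketch("/tmp/mlflow_test"))  # e.g. "file:///tmp/mlflow_test"
```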