From 98b2c948592981ac2ca5984054359f4978b50bb7 Mon Sep 17 00:00:00 2001 From: Kevin Date: Wed, 13 Sep 2023 12:37:42 -0400 Subject: [PATCH 1/3] Add stats_sender to MonaiAlgo for FL stats Signed-off-by: Kevin --- monai/fl/client/monai_algo.py | 10 +++++++++- monai/fl/utils/constants.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 626bc9651d..312dc0fc64 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -14,7 +14,7 @@ import os import time from collections.abc import Mapping, MutableMapping -from typing import Any, cast +from typing import Any, Callable, cast import torch import torch.distributed as dist @@ -359,6 +359,7 @@ def __init__( eval_workflow_name: str = "train", train_workflow: BundleWorkflow | None = None, eval_workflow: BundleWorkflow | None = None, + stats_sender: Callable | None = None, ): self.logger = logger self.bundle_root = bundle_root @@ -390,6 +391,7 @@ def __init__( if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: raise ValueError("train workflow must be BundleWorkflow and set type.") self.eval_workflow = eval_workflow + self.stats_sender = stats_sender self.app_root = "" self.filter_parser: ConfigParser | None = None @@ -478,6 +480,12 @@ def initialize(self, extra=None): if len(config_filter_files) > 0: self.filter_parser.read_config(config_filter_files) + # set stats sender for nvflare + self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) + if self.stats_sender is not None: + self.stats_sender.attach(self.trainer) + self.stats_sender.attach(self.evaluator) + # Get filters self.pre_filters = self.filter_parser.get_parsed_content( FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index 3f229d6ecc..eda1a6b4f9 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -29,6 +29,7 @@ class ExtraItems(StrEnum): MODEL_TYPE = "fl_model_type" CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" + STATS_SENDER = "fl_stats_sender" class FlPhase(StrEnum): From 83ab3379f664bf77b57ba83decca5c94377145a9 Mon Sep 17 00:00:00 2001 From: Kevin Date: Wed, 13 Sep 2023 15:56:07 -0400 Subject: [PATCH 2/3] remove stats_sender from init args Signed-off-by: Kevin --- monai/fl/client/monai_algo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 312dc0fc64..b2415bdd27 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -359,7 +359,6 @@ def __init__( eval_workflow_name: str = "train", train_workflow: BundleWorkflow | None = None, eval_workflow: BundleWorkflow | None = None, - stats_sender: Callable | None = None, ): self.logger = logger self.bundle_root = bundle_root @@ -391,7 +390,7 @@ def __init__( if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: raise ValueError("train workflow must be BundleWorkflow and set type.") self.eval_workflow = eval_workflow - self.stats_sender = stats_sender + self.stats_sender = None self.app_root = "" self.filter_parser: ConfigParser | None = None From 083e38862208d013a255429b3a2fe320462daeb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 19:56:38 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/fl/client/monai_algo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index b2415bdd27..9acf131bd9 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -14,7 +14,7 @@ import os import time from collections.abc import Mapping, MutableMapping -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.distributed as dist