diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 626bc9651d..9acf131bd9 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -390,6 +390,7 @@ def __init__( if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: raise ValueError("train workflow must be BundleWorkflow and set type.") self.eval_workflow = eval_workflow + self.stats_sender = None self.app_root = "" self.filter_parser: ConfigParser | None = None @@ -478,6 +479,12 @@ def initialize(self, extra=None): if len(config_filter_files) > 0: self.filter_parser.read_config(config_filter_files) + # set stats sender for nvflare + self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) + if self.stats_sender is not None: + self.stats_sender.attach(self.trainer) + self.stats_sender.attach(self.evaluator) + # Get filters self.pre_filters = self.filter_parser.get_parsed_content( FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index 3f229d6ecc..eda1a6b4f9 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -29,6 +29,7 @@ class ExtraItems(StrEnum): MODEL_TYPE = "fl_model_type" CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" + STATS_SENDER = "fl_stats_sender" class FlPhase(StrEnum):