From 6dfe17a7bfda9b4ddfa19b60886b7fa14e816afd Mon Sep 17 00:00:00 2001 From: nvkevlu <55759229+nvkevlu@users.noreply.github.com> Date: Wed, 13 Sep 2023 17:35:16 -0400 Subject: [PATCH] Add stats_sender to MonaiAlgo for FL stats (#6984) PR #6220 was closed and NVFlareStatsHandler has now been implemented in NVFlare in https://github.com/NVIDIA/NVFlare/pull/1925. However, there is still the piece in MonaiAlgo to attach the stats_sender in initialize, so this PR adds that missing piece. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Kevin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/fl/client/monai_algo.py | 7 +++++++ monai/fl/utils/constants.py | 1 + 2 files changed, 8 insertions(+) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 626bc9651d..9acf131bd9 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -390,6 +390,7 @@ def __init__( if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: raise ValueError("train workflow must be BundleWorkflow and set type.") self.eval_workflow = eval_workflow + self.stats_sender = None self.app_root = "" self.filter_parser: ConfigParser | None = None @@ -478,6 +479,12 @@ def initialize(self, extra=None): if len(config_filter_files) > 0: self.filter_parser.read_config(config_filter_files) + # set stats sender for nvflare + self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) + if self.stats_sender is not None: + self.stats_sender.attach(self.trainer) + self.stats_sender.attach(self.evaluator) + # Get filters self.pre_filters = self.filter_parser.get_parsed_content( FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index 3f229d6ecc..eda1a6b4f9 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -29,6 +29,7 @@ class ExtraItems(StrEnum): MODEL_TYPE = "fl_model_type" CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" + STATS_SENDER = "fl_stats_sender" class FlPhase(StrEnum):