Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(framework:skip) Enable SuperNode to complete context registration when FAB is not installed #4049

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,14 @@ def _on_backoff(retry_state: RetryState) -> None:

# Register context for this run
node_state.register_context(
run_id=run_id, run=run, flwr_path=flwr_path
run_id=run_id,
run=run,
panh99 marked this conversation as resolved.
Show resolved Hide resolved
flwr_path=flwr_path,
fab=fab,
)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=run_id)

# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
Expand Down
21 changes: 17 additions & 4 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
from typing import Dict, Optional

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config, get_fused_config_from_dir
from flwr.common.typing import Run, UserConfig
from flwr.common.config import (
get_fused_config,
get_fused_config_from_dir,
get_fused_config_from_fab,
)
from flwr.common.typing import Fab, Run, UserConfig


@dataclass()
Expand All @@ -44,12 +48,14 @@ def __init__(
self.node_config = node_config
self.run_infos: Dict[int, RunInfo] = {}

# pylint: disable=too-many-arguments
def register_context(
self,
run_id: int,
run: Optional[Run] = None,
flwr_path: Optional[Path] = None,
app_dir: Optional[str] = None,
fab: Optional[Fab] = None,
) -> None:
"""Register new run context for this node."""
if run_id not in self.run_infos:
Expand All @@ -65,8 +71,15 @@ def register_context(
else:
raise ValueError("The specified `app_dir` must be a directory.")
else:
# Load from .fab
initial_run_config = get_fused_config(run, flwr_path) if run else {}
if run:
if fab:
# Load pyproject.toml from FAB file and fuse
initial_run_config = get_fused_config_from_fab(fab.content, run)
else:
# Load pyproject.toml from installed FAB and fuse
initial_run_config = get_fused_config(run, flwr_path)
else:
initial_run_config = {}
panh99 marked this conversation as resolved.
Show resolved Hide resolved
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
Expand Down
14 changes: 13 additions & 1 deletion src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import tomli

from flwr.cli.config_utils import validate_fields
from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import Run, UserConfig, UserConfigValue

Expand Down Expand Up @@ -104,6 +104,18 @@ def get_fused_config_from_dir(
return fuse_dicts(flat_default_config, override_config)


def get_fused_config_from_fab(fab_file: Union[Path, bytes], run: Run) -> UserConfig:
"""Fuse default config in a `FAB` with overrides in a `Run`.

This enables obtaining a run-config without having to install the FAB. This
function mirrors `get_fused_config_from_dir`. This is useful when the execution
of the FAB is delegated to a different process.
"""
default_config = get_fab_config(fab_file)["tool"]["flwr"]["app"].get("config", {})
flat_config_flat = flatten_dict(default_config)
return fuse_dicts(flat_config_flat, run.override_config)


def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
"""Merge the overrides from a `Run` with the config from a FAB.

Expand Down