Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Enable configuring the simulation backend via flwr run #4059

Merged
merged 11 commits into from
Aug 23, 2024
10 changes: 10 additions & 0 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flower command line interface `run` command."""

import hashlib
import json
import subprocess
import sys
from logging import DEBUG
Expand Down Expand Up @@ -192,6 +193,8 @@ def _run_without_superexec(
) -> None:
try:
num_supernodes = federation_config["options"]["num-supernodes"]
verbose: Optional[bool] = federation_config["options"].get("verbose")
backend_cfg = federation_config["options"].get("backend", {})
except KeyError as err:
typer.secho(
"❌ The project's `pyproject.toml` needs to declare the number of"
Expand All @@ -212,6 +215,13 @@ def _run_without_superexec(
f"{num_supernodes}",
]

if backend_cfg:
# Stringify as JSON
command.extend(["--backend-config", json.dumps(backend_cfg)])

if verbose:
command.extend(["--verbose"])

if config_overrides:
command.extend(["--run-config", f"{' '.join(config_overrides)}"])

Expand Down
24 changes: 20 additions & 4 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from logging import DEBUG, ERROR, INFO, WARNING
from pathlib import Path
from time import sleep
from typing import List, Optional
from typing import Any, List, Optional

from flwr.cli.config_utils import load_and_validate
from flwr.client import ClientApp
Expand Down Expand Up @@ -91,6 +91,17 @@ def _resolve_message(conflict_keys: List[str]) -> str:
return True


def _replace_keys(d: Any, match: str, target: str) -> Any:
if isinstance(d, dict):
return {
k.replace(match, target): _replace_keys(v, match, target)
for k, v in d.items()
}
if isinstance(d, list):
return [_replace_keys(i, match, target) for i in d]
return d


# Entry point from CLI
# pylint: disable=too-many-locals
def run_simulation_from_cli() -> None:
Expand All @@ -105,6 +116,14 @@ def run_simulation_from_cli() -> None:
code_example='TF_FORCE_GPU_ALLOW_GROWTH="true" flower-simulation <...>',
)

# Load JSON config
backend_config_dict = json.loads(args.backend_config)

if backend_config_dict:
# Backend config internally operates with `_` not with `-`
backend_config_dict = _replace_keys(backend_config_dict, match="-", target="_")
log(DEBUG, "backend_config_dict: %s", backend_config_dict)

# We are supporting two modes for the CLI entrypoint:
# 1) Running an app dir containing a `pyproject.toml`
# 2) Running any ClientApp and ServerApp w/o pyproject.toml being present
Expand Down Expand Up @@ -167,9 +186,6 @@ def run_simulation_from_cli() -> None:
override_config=override_config,
)

# Load JSON config
backend_config_dict = json.loads(args.backend_config)

_run_simulation(
server_app_attr=server_app_attr,
client_app_attr=client_app_attr,
Expand Down
21 changes: 20 additions & 1 deletion src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Simulation engine executor."""


import json
import subprocess
import sys
from logging import ERROR, INFO, WARN
Expand All @@ -24,6 +25,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.cli.install import install_from_fab
from flwr.common.config import unflatten_dict
from flwr.common.constant import RUN_ID_NUM_BYTES
from flwr.common.logger import log
from flwr.common.typing import UserConfig
Expand Down Expand Up @@ -108,6 +110,7 @@ def set_config(
)
self.verbose = verbose

# pylint: disable=too-many-locals
@override
def start_run(
self,
Expand Down Expand Up @@ -152,6 +155,15 @@ def start_run(
"Config extracted from FAB's pyproject.toml is not valid"
)

# Unflatten the federation config
federation_config_flat = unflatten_dict(federation_config)

num_supernodes = federation_config_flat.get(
"num-supernodes", self.num_supernodes
)
backend_cfg = federation_config_flat.get("backend", {})
verbose: Optional[bool] = federation_config_flat.get("verbose")

# In Simulation there is no SuperLink, still we create a run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
log(INFO, "Created run %s", str(run_id))
Expand All @@ -162,11 +174,18 @@ def start_run(
"--app",
f"{str(fab_path)}",
"--num-supernodes",
f"{federation_config.get('num-supernodes', self.num_supernodes)}",
f"{num_supernodes}",
"--run-id",
str(run_id),
]

if backend_cfg:
# Stringify as JSON
command.extend(["--backend-config", json.dumps(backend_cfg)])

if verbose:
command.extend(["--verbose"])

if override_config:
override_config_str = _user_config_to_str(override_config)
command.extend(["--run-config", f"{override_config_str}"])
Expand Down