Skip to content

Commit

Permalink
feat(framework) Enable configuring the simulation backend via `flwr r…
Browse files Browse the repository at this point in the history
…un` (#4059)
  • Loading branch information
jafermarq committed Aug 23, 2024
1 parent d2ffc14 commit cca0fab
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
10 changes: 10 additions & 0 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flower command line interface `run` command."""

import hashlib
import json
import subprocess
import sys
from logging import DEBUG
Expand Down Expand Up @@ -192,6 +193,8 @@ def _run_without_superexec(
) -> None:
try:
num_supernodes = federation_config["options"]["num-supernodes"]
verbose: Optional[bool] = federation_config["options"].get("verbose")
backend_cfg = federation_config["options"].get("backend", {})
except KeyError as err:
typer.secho(
"❌ The project's `pyproject.toml` needs to declare the number of"
Expand All @@ -212,6 +215,13 @@ def _run_without_superexec(
f"{num_supernodes}",
]

if backend_cfg:
# Stringify as JSON
command.extend(["--backend-config", json.dumps(backend_cfg)])

if verbose:
command.extend(["--verbose"])

if config_overrides:
command.extend(["--run-config", f"{' '.join(config_overrides)}"])

Expand Down
24 changes: 20 additions & 4 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from logging import DEBUG, ERROR, INFO, WARNING
from pathlib import Path
from time import sleep
from typing import List, Optional
from typing import Any, List, Optional

from flwr.cli.config_utils import load_and_validate
from flwr.client import ClientApp
Expand Down Expand Up @@ -91,6 +91,17 @@ def _resolve_message(conflict_keys: List[str]) -> str:
return True


def _replace_keys(d: Any, match: str, target: str) -> Any:
if isinstance(d, dict):
return {
k.replace(match, target): _replace_keys(v, match, target)
for k, v in d.items()
}
if isinstance(d, list):
return [_replace_keys(i, match, target) for i in d]
return d


# Entry point from CLI
# pylint: disable=too-many-locals
def run_simulation_from_cli() -> None:
Expand All @@ -105,6 +116,14 @@ def run_simulation_from_cli() -> None:
code_example='TF_FORCE_GPU_ALLOW_GROWTH="true" flower-simulation <...>',
)

# Load JSON config
backend_config_dict = json.loads(args.backend_config)

if backend_config_dict:
# Backend config internally operates with `_` not with `-`
backend_config_dict = _replace_keys(backend_config_dict, match="-", target="_")
log(DEBUG, "backend_config_dict: %s", backend_config_dict)

# We are supporting two modes for the CLI entrypoint:
# 1) Running an app dir containing a `pyproject.toml`
# 2) Running any ClientApp and SeverApp w/o pyproject.toml being present
Expand Down Expand Up @@ -167,9 +186,6 @@ def run_simulation_from_cli() -> None:
override_config=override_config,
)

# Load JSON config
backend_config_dict = json.loads(args.backend_config)

_run_simulation(
server_app_attr=server_app_attr,
client_app_attr=client_app_attr,
Expand Down
21 changes: 20 additions & 1 deletion src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Simulation engine executor."""


import json
import subprocess
import sys
from logging import ERROR, INFO, WARN
Expand All @@ -24,6 +25,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.cli.install import install_from_fab
from flwr.common.config import unflatten_dict
from flwr.common.constant import RUN_ID_NUM_BYTES
from flwr.common.logger import log
from flwr.common.typing import UserConfig
Expand Down Expand Up @@ -108,6 +110,7 @@ def set_config(
)
self.verbose = verbose

# pylint: disable=too-many-locals
@override
def start_run(
self,
Expand Down Expand Up @@ -152,6 +155,15 @@ def start_run(
"Config extracted from FAB's pyproject.toml is not valid"
)

# Flatten federated config
federation_config_flat = unflatten_dict(federation_config)

num_supernodes = federation_config_flat.get(
"num-supernodes", self.num_supernodes
)
backend_cfg = federation_config_flat.get("backend", {})
verbose: Optional[bool] = federation_config_flat.get("verbose")

# In Simulation there is no SuperLink, still we create a run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
log(INFO, "Created run %s", str(run_id))
Expand All @@ -162,11 +174,18 @@ def start_run(
"--app",
f"{str(fab_path)}",
"--num-supernodes",
f"{federation_config.get('num-supernodes', self.num_supernodes)}",
f"{num_supernodes}",
"--run-id",
str(run_id),
]

if backend_cfg:
# Stringify as JSON
command.extend(["--backend-config", json.dumps(backend_cfg)])

if verbose:
command.extend(["--verbose"])

if override_config:
override_config_str = _user_config_to_str(override_config)
command.extend(["--run-config", f"{override_config_str}"])
Expand Down

0 comments on commit cca0fab

Please sign in to comment.