Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add node-config arg to SuperNode #3782

Merged
merged 15 commits into from
Jul 12, 2024
16 changes: 6 additions & 10 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class `flwr.client.Client` (default: None)
event(EventType.START_CLIENT_ENTER)
_start_client_internal(
server_address=server_address,
node_config={},
load_client_app_fn=None,
client_fn=client_fn,
client=client,
Expand All @@ -181,6 +182,7 @@ class `flwr.client.Client` (default: None)
def _start_client_internal(
*,
server_address: str,
node_config: Dict[str, str],
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
Expand All @@ -193,7 +195,6 @@ def _start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand All @@ -204,6 +205,8 @@ def _start_client_internal(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
node_config: Dict[str, str]
The configuration of the node.
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
A function that can be used to load a `ClientApp` instance.
client_fn : Optional[ClientFnExt]
Expand Down Expand Up @@ -238,9 +241,6 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
partition_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
Expand Down Expand Up @@ -319,10 +319,6 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)

# Empty dict (for now)
# This will be removed once users can pass node_config via flower-supernode
node_config: Dict[str, str] = {}

# NodeState gets initialized when the first connection is established
node_state: Optional[NodeState] = None

Expand Down Expand Up @@ -353,7 +349,7 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=-1,
node_config={},
partition_id=partition_id,
partition_id=None,
)
else:
# Call create_node fn to register node
Expand All @@ -365,7 +361,7 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=node_id,
node_config=node_config,
partition_id=partition_id,
partition_id=None,
)

app_state_tracker.register_signal_handler()
Expand Down
18 changes: 12 additions & 6 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import EventType, event
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
from flwr.common.config import (
get_flwr_dir,
get_project_config,
get_project_dir,
parse_config_args,
)
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
Expand Down Expand Up @@ -67,7 +72,7 @@ def run_supernode() -> None:
authentication_keys=authentication_keys,
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
partition_id=args.partition_id,
node_config=parse_config_args(args.node_config),
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
flwr_dir=get_flwr_dir(args.flwr_dir),
)

Expand All @@ -93,6 +98,7 @@ def run_client_app() -> None:

_start_client_internal(
server_address=args.superlink,
node_config=parse_config_args(args.node_config),
load_client_app_fn=load_fn,
transport=args.transport,
root_certificates=root_certificates,
Expand Down Expand Up @@ -389,11 +395,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
help="The SuperNode's public key (as a path str) to enable authentication.",
)
parser.add_argument(
"--partition-id",
"--node-config",
type=int,
help="The data partition index associated with this SuperNode. Better suited "
"for prototyping purposes where a SuperNode might only load a fraction of an "
"artificially partitioned dataset (e.g. using `flwr-datasets`)",
help="A comma separated list of key/value pairs (separated by `=`) to "
"configure the SuperNode. "
'E.g, `--node-config key1="value1",partition-id=0,num-partitions=100`',
)


Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,16 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st


def parse_config_args(
config_overrides: Optional[str],
config: Optional[str],
separator: str = ",",
) -> Dict[str, str]:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, str] = {}

if config_overrides is None:
if config is None:
return overrides

overrides_list = config_overrides.split(separator)
overrides_list = config.split(separator)
if (
len(overrides_list) == 1
and "=" not in overrides_list
Expand Down