From f190074d47faaa0f4611bba899cbcebd74252c71 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Dec 2023 12:20:18 +0100 Subject: [PATCH 1/2] Refactor flower-client internals --- src/py/flwr/client/app.py | 88 +++++++++++++++++++++++++++------ src/py/flwr/client/flower.py | 2 +- src/py/flwr/common/telemetry.py | 4 ++ 3 files changed, 79 insertions(+), 15 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7ce7d51d3d4..9efb0748d9d 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -37,7 +37,7 @@ from flwr.common.logger import log, warn_experimental_feature from flwr.proto.task_pb2 import TaskIns, TaskRes -from .flower import load_callable +from .flower import load_flower_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message @@ -47,6 +47,8 @@ def run_client() -> None: """Run Flower client.""" + event(EventType.RUN_CLIENT_ENTER) + log(INFO, "Long-running Flower client starting") args = _parse_args_client().parse_args() @@ -80,16 +82,17 @@ def run_client() -> None: sys.path.insert(0, callable_dir) def _load() -> Flower: - flower: Flower = load_callable(args.callable) + flower: Flower = load_flower_callable(args.callable) return flower - return start_client( + _start_client_internal( server_address=args.server, - load_callable_fn=_load, + load_flower_callable_fn=_load, transport="grpc-rere", # Only root_certificates=root_certificates, insecure=args.insecure, ) + event(EventType.RUN_CLIENT_LEAVE) def _parse_args_client() -> argparse.ArgumentParser: @@ -149,7 +152,6 @@ def _check_actionable_client( def start_client( *, server_address: str, - load_callable_fn: Optional[Callable[[], Flower]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -165,8 +167,6 @@ def start_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - load_callable_fn : Optional[Callable[[], Flower]] (default: None) - ... client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -223,11 +223,73 @@ class `flwr.client.Client` (default: None) >>> ) """ event(EventType.START_CLIENT_ENTER) + _start_client_internal( + server_address=server_address, + load_flower_callable_fn=None, + client_fn=client_fn, + client=client, + grpc_max_message_length=grpc_max_message_length, + root_certificates=root_certificates, + insecure=insecure, + transport=transport, + ) + event(EventType.START_CLIENT_LEAVE) + +# pylint: disable=import-outside-toplevel +# pylint: disable=too-many-branches +# pylint: disable=too-many-locals +# pylint: disable=too-many-statements +def _start_client_internal( + *, + server_address: str, + load_flower_callable_fn: Optional[Callable[[], Flower]] = None, + client_fn: Optional[ClientFn] = None, + client: Optional[Client] = None, + grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, + root_certificates: Optional[Union[bytes, str]] = None, + insecure: Optional[bool] = None, + transport: Optional[str] = None, +) -> None: + """Start a Flower client node which connects to a Flower server. + + Parameters + ---------- + server_address : str + The IPv4 or IPv6 address of the server. If the Flower + server runs on the same machine on port 8080, then `server_address` + would be `"[::]:8080"`. + load_flower_callable_fn : Optional[Callable[[], Flower]] (default: None) + A function that can be used to load a `Flower` callable instance. + client_fn : Optional[ClientFn] + A callable that instantiates a Client. (default: None) + client : Optional[flwr.client.Client] + An implementation of the abstract base + class `flwr.client.Client` (default: None) + grpc_max_message_length : int (default: 536_870_912, this equals 512MB) + The maximum length of gRPC messages that can be exchanged with the + Flower server. The default should be sufficient for most models. + Users who train very large models might need to increase this + value. Note that the Flower server needs to be started with the + same value (see `flwr.server.start_server`), otherwise it will not + know about the increased limit and block larger messages. + root_certificates : Optional[Union[bytes, str]] (default: None) + The PEM-encoded root certificates as a byte string or a path string. + If provided, a secure connection using the certificates will be + established to an SSL-enabled Flower server. + insecure : bool (default: True) + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + transport : Optional[str] (default: None) + Configure the transport layer. Allowed values: + - 'grpc-bidi': gRPC, bidirectional streaming + - 'grpc-rere': gRPC, request-response (experimental) + - 'rest': HTTP (experimental) + """ if insecure is None: insecure = root_certificates is None - if load_callable_fn is None: + if load_flower_callable_fn is None: _check_actionable_client(client, client_fn) if client_fn is None: @@ -246,11 +308,11 @@ def single_client_factory( def _load_app() -> Flower: return Flower(client_fn=client_fn) - load_callable_fn = _load_app + load_flower_callable_fn = _load_app else: - warn_experimental_feature("`load_callable_fn`") + warn_experimental_feature("`load_flower_callable_fn`") - # At this point, only `load_callable_fn` should be used + # At this point, only `load_flower_callable_fn` should be used # Both `client` and `client_fn` must not be used directly # Initialize connection context manager @@ -284,7 +346,7 @@ def _load_app() -> Flower: break # Load app - app: Flower = load_callable_fn() + app: Flower = load_flower_callable_fn() # Handle task message fwd_msg: Fwd = Fwd( @@ -311,8 +373,6 @@ def _load_app() -> Flower: ) time.sleep(sleep_duration) - event(EventType.START_CLIENT_LEAVE) - def start_numpy_client( *, diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 5b083ee11b9..10c78ec45b4 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -72,7 +72,7 @@ class LoadCallableError(Exception): """.""" -def load_callable(module_attribute_str: str) -> Flower: +def load_flower_callable(module_attribute_str: str) -> Flower: """Load the `Flower` object specified in a module attribute string. The module/attribute string should have the form :. Valid diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index d56726d8337..1eed78786b7 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -152,6 +152,10 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A START_DRIVER_ENTER = auto() START_DRIVER_LEAVE = auto() + # Driver API and Fleet API + RUN_CLIENT_ENTER = auto() + RUN_CLIENT_LEAVE = auto() + # Use the ThreadPoolExecutor with max_workers=1 to have a queue # and also ensure that telemetry calls are not blocking. From f805da8e2ea7b618065a46aa8989f496f03e8f62 Mon Sep 17 00:00:00 2001 From: Taner Topal Date: Wed, 6 Dec 2023 12:29:27 +0100 Subject: [PATCH 2/2] Update src/py/flwr/common/telemetry.py --- src/py/flwr/common/telemetry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index 1eed78786b7..fed8b5a978b 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -152,7 +152,7 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A START_DRIVER_ENTER = auto() START_DRIVER_LEAVE = auto() - # Driver API and Fleet API + # SuperNode: flower-client RUN_CLIENT_ENTER = auto() RUN_CLIENT_LEAVE = auto()