Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Refactor ClientApp loading to use explicit arguments #3805

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ def run_supernode() -> None:
_warn_deprecated_server_arg(args)

root_certificates = _get_certificates(args)
load_fn = _get_load_client_app_fn(args, multi_app=True)
load_fn = _get_load_client_app_fn(
default_app_ref=getattr(args, "client-app"),
dir_arg=args.dir,
flwr_dir_arg=args.flwr_dir,
multi_app=True,
)
authentication_keys = _try_setup_client_authentication(args)

_start_client_internal(
Expand Down Expand Up @@ -93,7 +98,11 @@ def run_client_app() -> None:
_warn_deprecated_server_arg(args)

root_certificates = _get_certificates(args)
load_fn = _get_load_client_app_fn(args, multi_app=False)
load_fn = _get_load_client_app_fn(
default_app_ref=getattr(args, "client-app"),
dir_arg=args.dir,
multi_app=False,
)
authentication_keys = _try_setup_client_authentication(args)

_start_client_internal(
Expand Down Expand Up @@ -166,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:


def _get_load_client_app_fn(
args: argparse.Namespace, multi_app: bool
default_app_ref: str,
dir_arg: str,
multi_app: bool,
flwr_dir_arg: Optional[str] = None,
) -> Callable[[str, str], ClientApp]:
"""Get the load_client_app_fn function.

Expand All @@ -178,25 +190,24 @@ def _get_load_client_app_fn(
loads a default ClientApp.
"""
# Find the Flower directory containing Flower Apps (only for multi-app)
flwr_dir = Path("")
if "flwr_dir" in args:
if args.flwr_dir is None:
if not multi_app:
flwr_dir = Path("")
else:
if flwr_dir_arg is None:
flwr_dir = get_flwr_dir()
else:
flwr_dir = Path(args.flwr_dir).absolute()
flwr_dir = Path(flwr_dir_arg).absolute()

inserted_path = None

default_app_ref: str = getattr(args, "client-app")

if not multi_app:
log(
DEBUG,
"Flower SuperNode will load and validate ClientApp `%s`",
getattr(args, "client-app"),
default_app_ref,
)
# Insert sys.path
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()
sys.path.insert(0, str(dir_path))
inserted_path = str(dir_path)

Expand All @@ -208,7 +219,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp:
# If multi-app feature is disabled
if not multi_app:
# Get sys path to be inserted
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()

# Set app reference
client_app_ref = default_app_ref
Expand All @@ -221,7 +232,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp:

log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
# Get sys path to be inserted
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()

# Set app reference
client_app_ref = default_app_ref
Expand Down