Skip to content

Commit

Permalink
refactor(framework) Refactor ClientApp loading to use explicit argu…
Browse files Browse the repository at this point in the history
…ments (#3805)
  • Loading branch information
jafermarq authored and chongshenng committed Jul 16, 2024
1 parent dd37449 commit 0f7c0f7
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ def run_supernode() -> None:
_warn_deprecated_server_arg(args)

root_certificates = _get_certificates(args)
load_fn = _get_load_client_app_fn(args, multi_app=True)
load_fn = _get_load_client_app_fn(
default_app_ref=getattr(args, "client-app"),
dir_arg=args.dir,
flwr_dir_arg=args.flwr_dir,
multi_app=True,
)
authentication_keys = _try_setup_client_authentication(args)

_start_client_internal(
Expand Down Expand Up @@ -93,7 +98,11 @@ def run_client_app() -> None:
_warn_deprecated_server_arg(args)

root_certificates = _get_certificates(args)
load_fn = _get_load_client_app_fn(args, multi_app=False)
load_fn = _get_load_client_app_fn(
default_app_ref=getattr(args, "client-app"),
dir_arg=args.dir,
multi_app=False,
)
authentication_keys = _try_setup_client_authentication(args)

_start_client_internal(
Expand Down Expand Up @@ -166,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:


def _get_load_client_app_fn(
args: argparse.Namespace, multi_app: bool
default_app_ref: str,
dir_arg: str,
multi_app: bool,
flwr_dir_arg: Optional[str] = None,
) -> Callable[[str, str], ClientApp]:
"""Get the load_client_app_fn function.
Expand All @@ -178,25 +190,24 @@ def _get_load_client_app_fn(
loads a default ClientApp.
"""
# Find the Flower directory containing Flower Apps (only for multi-app)
flwr_dir = Path("")
if "flwr_dir" in args:
if args.flwr_dir is None:
if not multi_app:
flwr_dir = Path("")
else:
if flwr_dir_arg is None:
flwr_dir = get_flwr_dir()
else:
flwr_dir = Path(args.flwr_dir).absolute()
flwr_dir = Path(flwr_dir_arg).absolute()

inserted_path = None

default_app_ref: str = getattr(args, "client-app")

if not multi_app:
log(
DEBUG,
"Flower SuperNode will load and validate ClientApp `%s`",
getattr(args, "client-app"),
default_app_ref,
)
# Insert sys.path
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()
sys.path.insert(0, str(dir_path))
inserted_path = str(dir_path)

Expand All @@ -208,7 +219,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp:
# If multi-app feature is disabled
if not multi_app:
# Get sys path to be inserted
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()

# Set app reference
client_app_ref = default_app_ref
Expand All @@ -221,7 +232,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp:

log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
# Get sys path to be inserted
dir_path = Path(args.dir).absolute()
dir_path = Path(dir_arg).absolute()

# Set app reference
client_app_ref = default_app_ref
Expand Down

0 comments on commit 0f7c0f7

Please sign in to comment.