-
Notifications
You must be signed in to change notification settings - Fork 875
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8a039ed
commit 625dca1
Showing
7 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Flower with Pandas testing | ||
|
||
This directory is used for testing Flower with Pandas by computing an historgram of some simple Iris dataset. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import warnings | ||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
import flwr as fl | ||
|
||
|
||
df = pd.read_csv("./data/client.csv") | ||
|
||
column_names = ["sepal length (cm)", "sepal width (cm)"] | ||
|
||
|
||
def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: | ||
freqs, _ = np.histogram(df[col_name]) | ||
return freqs | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
def fit( | ||
self, parameters: List[np.ndarray], config: Dict[str, str] | ||
) -> Tuple[List[np.ndarray], int, Dict]: | ||
hist_list = [] | ||
# Execute query locally | ||
for c in column_names: | ||
hist = compute_hist(df, c) | ||
hist_list.append(hist) | ||
return ( | ||
hist_list, | ||
len(df), | ||
{}, | ||
) | ||
|
||
if __name__ == "__main__": | ||
# Start Flower client | ||
fl.client.start_numpy_client( | ||
server_address="127.0.0.1:8080", | ||
client=FlowerClient(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[build-system] | ||
requires = ["poetry-core>=1.4.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "quickstart_pandas" | ||
version = "0.1.0" | ||
description = "Pandas Federated Analytics Quickstart with Flower" | ||
authors = ["Ragy Haddad <ragy202@gmail.com>"] | ||
maintainers = ["The Flower Authors <hello@flower.dev>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.8" | ||
flwr = { path = "../../", develop = true, extras = ["simulation"] } | ||
numpy = "^1.21.0" | ||
pandas = "^2.0.0" | ||
scikit-learn = "^1.1.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import flwr as fl | ||
|
||
from strategy import FedAnalytics | ||
|
||
# Start Flower server | ||
hist = fl.server.start_server( | ||
server_address="0.0.0.0:8080", | ||
config=fl.server.ServerConfig(num_rounds=1), | ||
strategy=FedAnalytics(), | ||
) | ||
assert hist.metrics_centralized["Aggregated histograms"][1][1] == [ | ||
"Length:", | ||
"18", | ||
"46", | ||
"28", | ||
"54", | ||
"32", | ||
"52", | ||
"36", | ||
"12", | ||
"10", | ||
"12", | ||
"Width:", | ||
"8", | ||
"14", | ||
"44", | ||
"48", | ||
"74", | ||
"62", | ||
"20", | ||
"22", | ||
"4", | ||
"4", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from flwr.common import ( | ||
EvaluateIns, | ||
EvaluateRes, | ||
FitIns, | ||
FitRes, | ||
Parameters, | ||
Scalar, | ||
ndarrays_to_parameters, | ||
parameters_to_ndarrays, | ||
) | ||
from flwr.server.client_manager import ClientManager | ||
from flwr.server.client_proxy import ClientProxy | ||
from flwr.server.strategy import Strategy | ||
|
||
|
||
class FedAnalytics(Strategy): | ||
def __init__( | ||
self, compute_fns: List[Callable] = None, col_names: List[str] = None | ||
) -> None: | ||
super().__init__() | ||
|
||
def initialize_parameters( | ||
self, client_manager: Optional[ClientManager] = None | ||
) -> Optional[Parameters]: | ||
return None | ||
|
||
def configure_fit( | ||
self, server_round: int, parameters: Parameters, client_manager: ClientManager | ||
) -> List[Tuple[ClientProxy, FitIns]]: | ||
config = {} | ||
fit_ins = FitIns(parameters, config) | ||
clients = client_manager.sample(num_clients=2, min_num_clients=2) | ||
return [(client, fit_ins) for client in clients] | ||
|
||
def aggregate_fit( | ||
self, | ||
server_round: int, | ||
results: List[Tuple[ClientProxy, FitRes]], | ||
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], | ||
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: | ||
# Get results from fit | ||
# Convert results | ||
values_aggregated = [ | ||
(parameters_to_ndarrays(fit_res.parameters)) for _, fit_res in results | ||
] | ||
length_agg_hist = 0 | ||
width_agg_hist = 0 | ||
for val in values_aggregated: | ||
length_agg_hist += val[0] | ||
width_agg_hist += val[1] | ||
|
||
ndarr = np.concatenate( | ||
(["Length:"], length_agg_hist, ["Width:"], width_agg_hist) | ||
) | ||
return ndarrays_to_parameters(ndarr), {} | ||
|
||
def evaluate( | ||
self, server_round: int, parameters: Parameters | ||
) -> Optional[Tuple[float, Dict[str, Scalar]]]: | ||
agg_hist = [arr.item() for arr in parameters_to_ndarrays(parameters)] | ||
return 0, {"Aggregated histograms": agg_hist} | ||
|
||
def configure_evaluate( | ||
self, server_round: int, parameters: Parameters, client_manager: ClientManager | ||
) -> List[Tuple[ClientProxy, EvaluateIns]]: | ||
pass | ||
|
||
def aggregate_evaluate( | ||
self, | ||
server_round: int, | ||
results: List[Tuple[ClientProxy, EvaluateRes]], | ||
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], | ||
) -> Tuple[Optional[float], Dict[str, Scalar]]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/bin/bash | ||
set -e | ||
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ | ||
|
||
bash ../test.sh |