-
Notifications
You must be signed in to change notification settings - Fork 842
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add E2E tests for different strategies (#2140)
- Loading branch information
1 parent
dc0311e
commit b863fd0
Showing
5 changed files
with
196 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Flower with Tensorflow and multiple strategies testing | ||
|
||
This directory is used for testing Flower with Tensorflow by using the CIFAR10 dataset and a CNN. | ||
|
||
It uses a subset of size 1000 for the training data and 10 data points for the testing. | ||
|
||
It tests the following strategies: | ||
|
||
- FedMedian | ||
- FedTrimmedAvg | ||
- QFedAvg | ||
- FaultTolerantFedAvg | ||
- FedAvgM | ||
- FedAdam | ||
- FedAdagrad | ||
- FedYogi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
|
||
import flwr as fl | ||
import tensorflow as tf | ||
|
||
SUBSET_SIZE = 1000 | ||
|
||
# Make TensorFlow log less verbose | ||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | ||
|
||
def get_model(): | ||
model = tf.keras.models.Sequential([ | ||
tf.keras.layers.Flatten(input_shape=(28, 28)), | ||
tf.keras.layers.Dense(128, activation='relu'), | ||
tf.keras.layers.Dense(10) | ||
]) | ||
model.compile( | ||
optimizer=tf.keras.optimizers.Adam(0.001), | ||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], | ||
) | ||
return model | ||
|
||
model = get_model() | ||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() | ||
x_train, y_train = x_train[:SUBSET_SIZE], y_train[:SUBSET_SIZE] | ||
x_test, y_test = x_test[:SUBSET_SIZE], y_test[:SUBSET_SIZE] | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
def get_parameters(self, config): | ||
return model.get_weights() | ||
|
||
def fit(self, parameters, config): | ||
model.set_weights(parameters) | ||
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) | ||
return model.get_weights(), len(x_train), {} | ||
|
||
def evaluate(self, parameters, config): | ||
model.set_weights(parameters) | ||
loss, accuracy = model.evaluate(x_test, y_test, verbose=0) | ||
return loss, len(x_test), {"accuracy": accuracy} | ||
|
||
|
||
if __name__ == "__main__": | ||
# Start Flower client | ||
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[build-system] | ||
requires = ["poetry-core>=1.4.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "quickstart_tensorflow" | ||
version = "0.1.0" | ||
description = "Keras Federated Learning Quickstart with Flower" | ||
authors = ["The Flower Authors <hello@flower.dev>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.7,<3.11" | ||
flwr = { path = "../../", develop = true, extras = ["simulation"] } | ||
tensorflow-cpu = "^2.9.1, !=2.11.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from sys import argv | ||
|
||
import tensorflow as tf | ||
from client import SUBSET_SIZE, FlowerClient, get_model | ||
|
||
import flwr as fl | ||
from flwr.common import ndarrays_to_parameters | ||
from flwr.server.strategy import ( | ||
FaultTolerantFedAvg, | ||
FedAdagrad, | ||
FedAdam, | ||
FedAvgM, | ||
FedMedian, | ||
FedTrimmedAvg, | ||
FedYogi, | ||
QFedAvg, | ||
) | ||
|
||
STRATEGY_LIST = [ | ||
FedMedian, | ||
FedTrimmedAvg, | ||
QFedAvg, | ||
FaultTolerantFedAvg, | ||
FedAvgM, | ||
FedAdam, | ||
FedAdagrad, | ||
FedYogi, | ||
] | ||
OPT_IDX = 5 | ||
|
||
strat = argv[1] | ||
|
||
|
||
def get_strat(name): | ||
return [ | ||
(idx, strat) | ||
for idx, strat in enumerate(STRATEGY_LIST) | ||
if strat.__name__ == name | ||
][0] | ||
|
||
|
||
init_model = get_model() | ||
|
||
|
||
def client_fn(cid): | ||
_ = cid | ||
return FlowerClient() | ||
|
||
|
||
def evaluate(server_round, parameters, config): | ||
model = get_model() | ||
|
||
_, (x_test, y_test) = tf.keras.datasets.mnist.load_data() | ||
x_test, y_test = x_test[:SUBSET_SIZE], y_test[:SUBSET_SIZE] | ||
|
||
model.set_weights(parameters) | ||
|
||
loss, accuracy = model.evaluate(x_test, y_test) | ||
|
||
# return statistics | ||
return loss, {"accuracy": accuracy} | ||
|
||
|
||
strat_args = { | ||
"evaluate_fn": evaluate, | ||
"initial_parameters": ndarrays_to_parameters(init_model.get_weights()), | ||
} | ||
|
||
start_idx, strategy = get_strat(strat) | ||
|
||
if start_idx >= OPT_IDX: | ||
strat_args["tau"] = 0.01 | ||
|
||
hist = fl.simulation.start_simulation( | ||
client_fn=client_fn, | ||
num_clients=2, | ||
config=fl.server.ServerConfig(num_rounds=3), | ||
strategy=strategy(**strat_args), | ||
) | ||
|
||
assert ( | ||
hist.metrics_centralized["accuracy"][0][1] | ||
/ hist.metrics_centralized["accuracy"][-1][1] | ||
) <= 1.04 or (hist.losses_centralized[0][1] / hist.losses_centralized[-1][1]) >= 0.96 |