Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add E2E tests for different strategies #2140

Merged
merged 40 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ecf98d8
Add E2E tests for different strategies
charlesbvll Jul 18, 2023
3fb0d33
Add evaluate function
charlesbvll Jul 18, 2023
510770e
Use correct function for evaluation
charlesbvll Jul 18, 2023
3aaf704
Add strategies
charlesbvll Jul 18, 2023
7d6502b
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 18, 2023
50225cc
Fix typo
charlesbvll Jul 18, 2023
6c9d455
Convert ndarrays to parameters
charlesbvll Jul 18, 2023
8152262
Make clients less verbose and print strategy
charlesbvll Jul 18, 2023
8a53a49
Set higher tau value
charlesbvll Jul 18, 2023
abea8c3
Use tau only for opt strategies
charlesbvll Jul 18, 2023
ef30817
Increase the number of local epochs
charlesbvll Jul 18, 2023
d26a911
Increase the timeout time for strategies
charlesbvll Jul 18, 2023
2298b57
Make assertion less strict
charlesbvll Jul 18, 2023
c54f8c4
Use bigger test set for evaluation
charlesbvll Jul 18, 2023
d61b743
Increase timeout for strategies
charlesbvll Jul 18, 2023
fc148f7
Use accuracy instead of loss for assertion
charlesbvll Jul 18, 2023
b46f314
Correctly access the metrics dict
charlesbvll Jul 18, 2023
9fea6bd
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 19, 2023
67948ab
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 19, 2023
623b0c7
Use args
charlesbvll Jul 19, 2023
b890448
Use separated jobs
charlesbvll Jul 19, 2023
163e237
Use opt_idx to add tau
charlesbvll Jul 19, 2023
d89e7f1
Use loss as well as accuracy
charlesbvll Jul 19, 2023
7f7b7e7
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 19, 2023
a761c03
Use better names
charlesbvll Jul 19, 2023
a43c103
Merge branch 'add-strategies-e2e' of https://github.com/adap/flower i…
charlesbvll Jul 19, 2023
c737d9f
Fix typo
charlesbvll Jul 19, 2023
6f72475
Update README
charlesbvll Jul 19, 2023
837155e
Increase timeout length
charlesbvll Jul 19, 2023
6981e33
Replace CIFAR10 with MNIST
charlesbvll Jul 19, 2023
278515c
Use another model
charlesbvll Jul 19, 2023
b83ea27
Train 1 local epoch
charlesbvll Jul 19, 2023
d641f34
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 19, 2023
d86f580
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 20, 2023
32bafde
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 20, 2023
8147a50
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 20, 2023
16225fc
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 21, 2023
e2d8fd6
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 27, 2023
9556deb
Merge branch 'main' into add-strategies-e2e
charlesbvll Jul 31, 2023
3671053
Merge branch 'main' into add-strategies-e2e
danieljanes Jul 31, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ env:
FLWR_TELEMETRY_ENABLED: 0

jobs:
e2e:
frameworks:
runs-on: ubuntu-22.04
timeout-minutes: 10
# Using approach described here:
Expand Down Expand Up @@ -71,7 +71,7 @@ jobs:
Path('data').mkdir(exist_ok=True)
load_iris(as_frame=True)['data'].to_csv('./data/client.csv')

name: ${{matrix.directory}}
name: Framework / ${{matrix.directory}}

defaults:
run:
Expand All @@ -92,3 +92,35 @@ jobs:
run: ./test.sh
- name: Run virtual client test
run: python simulation.py

strategies:
runs-on: ubuntu-22.04
timeout-minutes: 10
strategy:
matrix:
strat: ["FedMedian", "FedTrimmedAvg", "QFedAvg", "FaultTolerantFedAvg", "FedAvgM", "FedAdam", "FedAdagrad", "FedYogi"]

name: Strategy / ${{ matrix.strat }}

defaults:
run:
working-directory: e2e/strategies

steps:
- uses: actions/checkout@v3
- name: Bootstrap
uses: ./.github/actions/bootstrap
- name: Install dependencies
run: |
python -m poetry install
- name: Cache Datasets
uses: actions/cache@v2
with:
path: "~/.keras"
key: keras-datasets
- name: Download Datasets
run: |
python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()"
- name: Test strategies
run: |
python test.py "${{ matrix.strat }}"
16 changes: 16 additions & 0 deletions e2e/strategies/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Flower with Tensorflow and multiple strategies testing

This directory is used for testing Flower with Tensorflow by using the CIFAR10 dataset and a CNN.

It uses a subset of size 1000 for the training data and 10 data points for the testing.

It tests the following strategies:

- FedMedian
- FedTrimmedAvg
- QFedAvg
- FaultTolerantFedAvg
- FedAvgM
- FedAdam
- FedAdagrad
- FedYogi
48 changes: 48 additions & 0 deletions e2e/strategies/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

import flwr as fl
import tensorflow as tf

SUBSET_SIZE = 1000

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

def get_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
return model

model = get_model()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, y_train = x_train[:SUBSET_SIZE], y_train[:SUBSET_SIZE]
x_test, y_test = x_test[:SUBSET_SIZE], y_test[:SUBSET_SIZE]


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return model.get_weights()

def fit(self, parameters, config):
model.set_weights(parameters)
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)
return model.get_weights(), len(x_train), {}

def evaluate(self, parameters, config):
model.set_weights(parameters)
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
return loss, len(x_test), {"accuracy": accuracy}


if __name__ == "__main__":
# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
14 changes: 14 additions & 0 deletions e2e/strategies/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart_tensorflow"
version = "0.1.0"
description = "Keras Federated Learning Quickstart with Flower"
authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = ">=3.7,<3.11"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
tensorflow-cpu = "^2.9.1, !=2.11.1"
84 changes: 84 additions & 0 deletions e2e/strategies/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from sys import argv

import tensorflow as tf
from client import SUBSET_SIZE, FlowerClient, get_model

import flwr as fl
from flwr.common import ndarrays_to_parameters
from flwr.server.strategy import (
FaultTolerantFedAvg,
FedAdagrad,
FedAdam,
FedAvgM,
FedMedian,
FedTrimmedAvg,
FedYogi,
QFedAvg,
)

STRATEGY_LIST = [
FedMedian,
FedTrimmedAvg,
QFedAvg,
FaultTolerantFedAvg,
FedAvgM,
FedAdam,
FedAdagrad,
FedYogi,
]
OPT_IDX = 5

strat = argv[1]


def get_strat(name):
return [
(idx, strat)
for idx, strat in enumerate(STRATEGY_LIST)
if strat.__name__ == name
][0]


init_model = get_model()


def client_fn(cid):
_ = cid
return FlowerClient()


def evaluate(server_round, parameters, config):
model = get_model()

_, (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test, y_test = x_test[:SUBSET_SIZE], y_test[:SUBSET_SIZE]

model.set_weights(parameters)

loss, accuracy = model.evaluate(x_test, y_test)

# return statistics
return loss, {"accuracy": accuracy}


strat_args = {
"evaluate_fn": evaluate,
"initial_parameters": ndarrays_to_parameters(init_model.get_weights()),
}

start_idx, strategy = get_strat(strat)

if start_idx >= OPT_IDX:
strat_args["tau"] = 0.01

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy(**strat_args),
)

assert (
hist.metrics_centralized["accuracy"][0][1]
/ hist.metrics_centralized["accuracy"][-1][1]
) <= 1.04 or (hist.losses_centralized[0][1] / hist.losses_centralized[-1][1]) >= 0.96
Loading