refactor(examples) Update Flower example for custom metrics (#3873)
Co-authored-by: jafermarq <javier@flower.ai>
1 parent 2d10f8a · commit 63dc00f

Showing 10 changed files with 283 additions and 239 deletions.
````diff
@@ -1,113 +1,75 @@
 ---
 title: Example Flower App with Custom Metrics
 tags: [basic, vision, fds]
 dataset: [CIFAR-10]
-framework: [tensorflow]
+framework: [tensorflow, scikit-learn]
 ---

-# Flower Example using Custom Metrics
+# Custom Metrics for Federated Learning with TensorFlow and Flower

-This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score.
+This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available [scikit-learn metrics](https://scikit-learn.org/stable/modules/model_evaluation.html): accuracy, recall, precision, and f1-score.

-Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other metrics or custom ones are possible to be calculated.
+Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client_app.py`), other metrics or custom ones are possible to be calculated.

 The main takeaways of this implementation are:

-- the use of the `output_dict` on the client side - inside `evaluate` method on `client.py`
-- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server.py`
+- the return of multiple evaluation metrics generated at the `evaluate` method on `client_app.py`
+- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server_app.py`

 This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.ai/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve the CIFAR-10.

 Using the CIFAR-10 dataset for classification, this is a multi-class classification problem, thus some changes on how to calculate the metrics using `average='micro'` and `np.argmax` is required. For binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on client side.

-## Project Setup
+## Set up the project

-Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+### Clone the project

-```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
-```
-
-This will create a new directory called `custom-metrics` containing the following files:
-
-```shell
--- pyproject.toml
--- requirements.txt
--- client.py
--- server.py
--- run.sh
--- README.md
-```
-
-### Installing Dependencies
-
-Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
-
-#### Poetry
+Start by cloning the example project:

 ```shell
-poetry install
-poetry shell
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+&& mv _tmp/examples/custom-metrics . \
+&& rm -rf _tmp && cd custom-metrics
 ```

-Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
+This will create a new directory called `custom-metrics` containing the
+following files:

 ```shell
-poetry run python3 -c "import flwr"
+custom-metrics
+├── README.md
+├── custommetrics_example
+│   ├── __init__.py
+│   ├── client_app.py   # Defines your ClientApp
+│   ├── server_app.py   # Defines your ServerApp
+│   └── task.py         # Defines your model and dataloading functions
+└── pyproject.toml      # Project metadata like dependencies and configs
 ```

-If you don't see any errors you're good to go!
+## Install dependencies and project

-#### pip
+Install the dependencies defined in `pyproject.toml` as well as the `custommetrics_example` package.

-Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.
-
-```shell
-python -m venv venv
-source venv/bin/activate
-pip install -r requirements.txt
+```bash
+pip install -e .
 ```

-## Run Federated Learning with Custom Metrics
+## Run the Example

-Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

-```shell
-python server.py
-```
-
-Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each:
+### Run with the Simulation Engine

-```shell
-python client.py
+```bash
+flwr run .
 ```

-Alternatively you can run all of it in one shell as follows:
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

-```shell
-python server.py &
-# Wait for a few seconds to give the server enough time to start, then:
-python client.py &
-python client.py
+```bash
+flwr run . --run-config num-server-rounds=5
 ```

-or
+### Run with the Deployment Engine

-```shell
-chmod +x run.sh
-./run.sh
-```
-
-You will see that Keras is starting a federated training. Have a look to the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
-
-Running `run.sh` will result in the following output (after 3 rounds):
-
-```shell
-INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed {
-'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)],
-'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)]
-}
-```
+> \[!NOTE\]
+> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.
````
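The updated README keeps using [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve CIFAR-10, but `task.py` is not included in this commit view. For orientation only, a `load_data` helper roughly along the lines of the sketch below would match the call made in `client_app.py` further down; the dataset identifier, the partitioner choice, and the `img`/`label` column names are assumptions, not the committed code.

```python
# Hypothetical sketch of load_data() -- task.py is not shown in this commit view.
# The dataset id and the "img"/"label" column names are assumptions.
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner


def load_data(partition_id: int, num_partitions: int):
    # Download CIFAR-10 once and split it into IID partitions, one per client
    fds = FederatedDataset(
        dataset="cifar10",
        partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
    )
    partition = fds.load_partition(partition_id, "train")

    # Hold out 20% of this client's partition for local evaluation
    split = partition.train_test_split(test_size=0.2, seed=42)
    train = split["train"].with_format("numpy")
    test = split["test"].with_format("numpy")
    x_train, y_train = train["img"] / 255.0, train["label"]
    x_test, y_test = test["img"] / 255.0, test["label"]
    return x_train, y_train, x_test, y_test
```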
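Likewise, the README's paragraph about `average='micro'` and `np.argmax` refers to the `eval_learning` helper that `client_app.py` imports from `task.py`. A minimal sketch of how such a helper could compute the scikit-learn metrics for the multi-class case is shown below; it is an illustration, not the committed implementation.

```python
# Hypothetical sketch of eval_learning() -- the committed task.py is not shown here.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def eval_learning(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    # average="micro" aggregates contributions over all 10 CIFAR-10 classes;
    # for a binary task the default averaging would suffice.
    rec = recall_score(y_test, y_pred, average="micro")
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1
```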
examples/custom-metrics/custommetrics_example/__init__.py
@@ -0,0 +1 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""
```
examples/custom-metrics/custommetrics_example/client_app.py
68 changes: 68 additions & 0 deletions
@@ -0,0 +1,68 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""

import os

import numpy as np
from custommetrics_example.task import eval_learning, get_model, load_data

from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# Define Flower client
class FlowerClient(NumPyClient):
    # pylint: disable=too-many-arguments
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        self.model.fit(
            self.x_train, self.y_train, epochs=1, batch_size=32, verbose=False
        )
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=False)
        y_pred = self.model.predict(self.x_test, verbose=False)
        y_pred = np.argmax(y_pred, axis=1).reshape(
            -1, 1
        )  # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable

        acc, rec, prec, f1 = eval_learning(self.y_test, y_pred)
        output_dict = {
            "accuracy": accuracy,  # accuracy from tensorflow model.evaluate
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
        }
        return loss, len(self.x_test), output_dict


def client_fn(context: Context) -> Client:
    """Construct a Client that will be run in a ClientApp."""

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    # Load the train and test data
    x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)

    model = get_model()

    # Return Client instance
    return FlowerClient(model, x_train, y_train, x_test, y_test).to_client()


# Create ClientApp
app = ClientApp(client_fn=client_fn)
```
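One way to sanity-check the custom metrics produced by `evaluate` is to exercise the client outside of a Flower run. The snippet below is a hypothetical local check built only on the imports above; it is not part of this commit.

```python
# Hypothetical local check -- not part of this commit.
from custommetrics_example.client_app import FlowerClient
from custommetrics_example.task import get_model, load_data

# Pretend to be partition 0 of 2
x_train, y_train, x_test, y_test = load_data(0, 2)
model = get_model()
client = FlowerClient(model, x_train, y_train, x_test, y_test)

# Evaluate the freshly initialized weights; `metrics` should contain
# "accuracy", "acc", "rec", "prec", and "f1"
loss, num_examples, metrics = client.evaluate(model.get_weights(), config={})
print(loss, num_examples, metrics)
```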
examples/custom-metrics/custommetrics_example/server_app.py
82 changes: 82 additions & 0 deletions
@@ -0,0 +1,82 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""

import numpy as np
from custommetrics_example.task import get_model, get_parameters

from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg


# Define metrics aggregation function
def average_metrics(metrics):
    # pylint: disable=C0301
    """Aggregate metrics from multiple clients by calculating mean averages.
    Parameters
    ----------
    metrics : list
        A list containing tuples, where each tuple represents metrics for a client.
        Each tuple is structured as (num_examples, metric), where:
        - num_examples (int) : The number of examples used to compute the metrics.
        - metric (dict) : A dictionary containing custom metrics provided as
          `output_dict` in the `evaluate` method from `client.py`.
    Returns
    -------
    dict
        A dictionary with the aggregated metrics, calculating mean averages.
        The keys of the dictionary represent different metrics, including:
        - 'accuracy': Mean accuracy calculated by TensorFlow.
        - 'acc': Mean accuracy from scikit-learn.
        - 'rec': Mean recall from scikit-learn.
        - 'prec': Mean precision from scikit-learn.
        - 'f1': Mean F1 score from scikit-learn.
    Note: If a weighted average is required, the `num_examples` parameter can be
    leveraged.
    Example:
        Example `metrics` list for two clients after the last round:
        [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}),
        (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})]
    """

    # Here num_examples are not taken into account by using _
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
    accuracies = np.mean([metric["acc"] for _, metric in metrics])
    recalls = np.mean([metric["rec"] for _, metric in metrics])
    precisions = np.mean([metric["prec"] for _, metric in metrics])
    f1s = np.mean([metric["f1"] for _, metric in metrics])

    return {
        "accuracy": accuracies_tf,
        "acc": accuracies,
        "rec": recalls,
        "prec": precisions,
        "f1": f1s,
    }


def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour."""

    # Read from config
    num_rounds = context.run_config["num-server-rounds"]

    model = get_model()
    ndarrays = get_parameters(model)
    global_model_init = ndarrays_to_parameters(ndarrays)

    # Define strategy and the custom aggregation function for the evaluation metrics
    strategy = FedAvg(
        evaluate_metrics_aggregation_fn=average_metrics,
        initial_parameters=global_model_init,
    )
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
```
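The docstring above notes that `num_examples` can be leveraged if a weighted average is required. A minimal sketch of such a variant, not part of this commit, that accepts the same list of `(num_examples, metric_dict)` tuples could look like the following; it could be passed to `FedAvg` through the same `evaluate_metrics_aggregation_fn` argument used in `server_fn`.

```python
# Hypothetical weighted variant of average_metrics() -- not part of this commit.
def weighted_average_metrics(metrics):
    """Weight each client's metrics by its number of evaluation examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    keys = ["accuracy", "acc", "rec", "prec", "f1"]
    return {
        key: sum(num_examples * metric[key] for num_examples, metric in metrics)
        / total_examples
        for key in keys
    }
```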