refactor(examples) Update Flower example for custom metrics (#3873)
Co-authored-by: jafermarq <javier@flower.ai>
chongshenng and jafermarq authored Aug 9, 2024
1 parent 2d10f8a commit 63dc00f
Showing 10 changed files with 283 additions and 239 deletions.
112 changes: 37 additions & 75 deletions examples/custom-metrics/README.md
````diff
@@ -1,113 +1,75 @@
 ---
 title: Example Flower App with Custom Metrics
 tags: [basic, vision, fds]
 dataset: [CIFAR-10]
-framework: [tensorflow]
+framework: [tensorflow, scikit-learn]
 ---

-# Flower Example using Custom Metrics
+# Custom Metrics for Federated Learning with TensorFlow and Flower

-This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score.
+This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of readily available [scikit-learn metrics](https://scikit-learn.org/stable/modules/model_evaluation.html): accuracy, recall, precision, and f1-score.

-Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other metrics or custom ones are possible to be calculated.
+Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client_app.py`), other standard or custom metrics can be calculated.

 The main takeaways of this implementation are:

-- the use of the `output_dict` on the client side - inside `evaluate` method on `client.py`
-- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server.py`
+- the return of multiple evaluation metrics generated by the `evaluate` method in `client_app.py`
+- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` in `server_app.py`

 This example is based on `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.ai/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.ai/docs/datasets/index.html) to retrieve CIFAR-10.

 Because CIFAR-10 classification is a multi-class problem, some changes to how the metrics are calculated are required, using `average='micro'` and `np.argmax`. For binary classification, this is not required. For unsupervised learning tasks, such as a deep autoencoder, a custom metric based on the reconstruction error could be implemented on the client side.

-## Project Setup
+## Set up the project

-Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+### Clone the project
+
+Start by cloning the example project:

 ```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+        && mv _tmp/examples/custom-metrics . \
+        && rm -rf _tmp && cd custom-metrics
 ```

-This will create a new directory called `custom-metrics` containing the following files:
+This will create a new directory called `custom-metrics` containing the
+following files:

 ```shell
--- pyproject.toml
--- requirements.txt
--- client.py
--- server.py
--- run.sh
--- README.md
+custom-metrics
+├── README.md
+├── custommetrics_example
+│   ├── __init__.py
+│   ├── client_app.py   # Defines your ClientApp
+│   ├── server_app.py   # Defines your ServerApp
+│   └── task.py         # Defines your model and data loading functions
+└── pyproject.toml      # Project metadata like dependencies and configs
 ```

-### Installing Dependencies
+## Install dependencies and project

-Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+Install the dependencies defined in `pyproject.toml` as well as the `custommetrics_example` package.

-#### Poetry
-
-```shell
-poetry install
-poetry shell
-```
-
-Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
-
-```shell
-poetry run python3 -c "import flwr"
-```
-
-If you don't see any errors you're good to go!
-
-#### pip
-
-Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.
-
-```shell
-python -m venv venv
-source venv/bin/activate
-pip install -r requirements.txt
+```bash
+pip install -e .
 ```

-## Run Federated Learning with Custom Metrics
+## Run the Example

-Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

-```shell
-python server.py
-```
-
-Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each:
+### Run with the Simulation Engine

-```shell
-python client.py
+```bash
+flwr run .
 ```

-Alternatively you can run all of it in one shell as follows:
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

-```shell
-python server.py &
-# Wait for a few seconds to give the server enough time to start, then:
-python client.py &
-python client.py
-```
-
-or
-
-```shell
-chmod +x run.sh
-./run.sh
+```bash
+flwr run . --run-config num-server-rounds=5
 ```

-You will see that Keras is starting a federated training. Have a look at the [Flower Quickstart documentation](https://flower.ai/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
+### Run with the Deployment Engine

-Running `run.sh` will result in the following output (after 3 rounds):
-
-```shell
-INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed {
-    'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)],
-    'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-    'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-    'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
-    'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)]
-}
-```
+> \[!NOTE\]
+> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.
````
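The multi-class metric computation described in the README above (micro-averaging plus `np.argmax`) lives in `custommetrics_example/task.py`, whose diff is not rendered on this page. A minimal sketch of what such an `eval_learning` helper could look like, assuming the four scikit-learn metrics the README names:

```python
# Sketch of an `eval_learning` helper as described in the README above.
# The actual task.py is not shown in this diff, so treat this as an
# illustration, not the committed implementation.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def eval_learning(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    # CIFAR-10 is multi-class, so micro-averaging is used; for binary
    # classification the default `average` setting would suffice
    rec = recall_score(y_test, y_pred, average="micro")
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1
```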
73 changes: 0 additions & 73 deletions examples/custom-metrics/client.py

This file was deleted.

1 change: 1 addition & 0 deletions examples/custom-metrics/custommetrics_example/__init__.py
@@ -0,0 +1 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""
```
68 changes: 68 additions & 0 deletions examples/custom-metrics/custommetrics_example/client_app.py
@@ -0,0 +1,68 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""

import os

import numpy as np
from custommetrics_example.task import eval_learning, get_model, load_data

from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# Define Flower client
class FlowerClient(NumPyClient):
    # pylint: disable=too-many-arguments
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        self.model.fit(
            self.x_train, self.y_train, epochs=1, batch_size=32, verbose=False
        )
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=False)
        y_pred = self.model.predict(self.x_test, verbose=False)
        # MobileNetV2 outputs probabilities for 10 classes; argmax returns
        # the most probable class for each sample
        y_pred = np.argmax(y_pred, axis=1).reshape(-1, 1)

        acc, rec, prec, f1 = eval_learning(self.y_test, y_pred)
        output_dict = {
            "accuracy": accuracy,  # accuracy from tensorflow model.evaluate
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
        }
        return loss, len(self.x_test), output_dict


def client_fn(context: Context) -> Client:
    """Construct a Client that will be run in a ClientApp."""

    # Read the node_config to fetch the data partition associated with this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    # Load the train and test data
    x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)

    model = get_model()

    # Return Client instance
    return FlowerClient(model, x_train, y_train, x_test, y_test).to_client()


# Create ClientApp
app = ClientApp(client_fn=client_fn)
```
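`get_model` and `load_data` are imported from `custommetrics_example/task.py`, which is not rendered on this page. A rough sketch of what they might contain, assuming the MobileNetV2 model from the deleted `client.py` and Flower Datasets for partitioning CIFAR-10:

```python
# Sketch of the model and data-loading helpers expected in task.py; the file
# is not shown in this diff, so the details here are assumptions.
import tensorflow as tf
from flwr_datasets import FederatedDataset


def get_model():
    # MobileNetV2 adapted to CIFAR-10, trained from scratch
    model = tf.keras.applications.MobileNetV2(
        input_shape=(32, 32, 3), classes=10, weights=None
    )
    model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
    return model


def load_data(partition_id, num_partitions):
    # Download and partition CIFAR-10 with Flower Datasets
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id, "train").with_format("numpy")
    partition = partition.train_test_split(test_size=0.2, seed=42)
    x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
    x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
    return x_train, y_train, x_test, y_test
```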
82 changes: 82 additions & 0 deletions examples/custom-metrics/custommetrics_example/server_app.py
@@ -0,0 +1,82 @@

```python
"""custommetrics_example: A Flower / TensorFlow app for custom metrics."""

import numpy as np
from custommetrics_example.task import get_model, get_parameters

from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg


# Define metrics aggregation function
def average_metrics(metrics):
    # pylint: disable=C0301
    """Aggregate metrics from multiple clients by calculating mean averages.

    Parameters
    ----------
    metrics : list
        A list containing tuples, where each tuple represents metrics for a client.
        Each tuple is structured as (num_examples, metric), where:

        - num_examples (int): The number of examples used to compute the metrics.
        - metric (dict): A dictionary containing custom metrics provided as
          `output_dict` in the `evaluate` method of `client_app.py`.

    Returns
    -------
    dict
        A dictionary with the aggregated metrics, calculated as mean averages.
        The keys of the dictionary represent different metrics, including:

        - 'accuracy': Mean accuracy calculated by TensorFlow.
        - 'acc': Mean accuracy from scikit-learn.
        - 'rec': Mean recall from scikit-learn.
        - 'prec': Mean precision from scikit-learn.
        - 'f1': Mean F1 score from scikit-learn.

    Note: If a weighted average is required, the `num_examples` parameter can be
    leveraged.

    Example:
        Example `metrics` list for two clients after the last round:
        [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}),
         (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})]
    """
    # num_examples is ignored here (hence the `_`); each client contributes equally
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
    accuracies = np.mean([metric["acc"] for _, metric in metrics])
    recalls = np.mean([metric["rec"] for _, metric in metrics])
    precisions = np.mean([metric["prec"] for _, metric in metrics])
    f1s = np.mean([metric["f1"] for _, metric in metrics])

    return {
        "accuracy": accuracies_tf,
        "acc": accuracies,
        "rec": recalls,
        "prec": precisions,
        "f1": f1s,
    }


def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour."""

    # Read from config
    num_rounds = context.run_config["num-server-rounds"]

    model = get_model()
    ndarrays = get_parameters(model)
    global_model_init = ndarrays_to_parameters(ndarrays)

    # Define the strategy and the custom aggregation function for the evaluation metrics
    strategy = FedAvg(
        evaluate_metrics_aggregation_fn=average_metrics,
        initial_parameters=global_model_init,
    )
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
```
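As the docstring above notes, `num_examples` can be leveraged if a weighted average is preferred over a plain mean. A minimal sketch of that variant (`weighted_average_metrics` is a hypothetical name, not part of this commit):

```python
# Sketch of a weighted variant of `average_metrics`: each client's metric is
# weighted by the number of examples it evaluated on.
def weighted_average_metrics(metrics):
    total_examples = sum(num_examples for num_examples, _ in metrics)
    return {
        key: sum(num_examples * metric[key] for num_examples, metric in metrics)
        / total_examples
        for key in ("accuracy", "acc", "rec", "prec", "f1")
    }
```

It would be passed to the strategy the same way, via `FedAvg(evaluate_metrics_aggregation_fn=weighted_average_metrics, ...)`.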