Add MLX quickstart example #2693

Merged (14 commits) on Dec 8, 2023
85 changes: 85 additions & 0 deletions examples/quickstart-mlx/README.md
@@ -0,0 +1,85 @@
# Flower Example using MLX

This introductory example to Flower uses [MLX](https://ml-explore.github.io/mlx/build/html/index.html), but deep knowledge of MLX is not required to run it. It will, however, help you understand how to adapt Flower to your use case. Running the example itself is quite easy.

[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon.
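
To give a feel for the API, here is a minimal sketch (not part of the example files) showing that MLX arrays behave much like NumPy arrays but are evaluated lazily:

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = a * 2 + 1  # builds a lazy computation graph; nothing is computed yet
mx.eval(b)     # forces evaluation on the default device
print(b)       # array([3, 5, 7], dtype=float32)
```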

In this example, we will train a simple MLP with two hidden layers on MNIST (handwritten digit recognition).

## Project Setup

Start by cloning the example project. We've prepared a single-line command that you can copy into your shell; it will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/quickstart-mlx . && rm -rf _tmp && cd quickstart-mlx
```

This will create a new directory called `quickstart-mlx` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- server.py
-- mnist.py
-- run.sh
-- README.md
```

### Installing Dependencies

Project dependencies (such as `mlx` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) ([installation instructions](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/) to install these dependencies and manage your virtual environment, but feel free to use a different tool if you prefer.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly, you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors, you're good to go!

#### pip

Run the command below in your terminal to install the dependencies listed in the configuration file `requirements.txt`.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with MLX and Flower

Afterwards, you are ready to start the Flower server as well as the clients. Start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients that will participate in the learning. To do so, open two more terminal windows and run the following commands.

Start a first client in the first terminal:

```shell
python3 client.py
```

And another one in the second terminal:

```shell
python3 client.py
```

If you want to use your GPU (MLX's Metal back-end on Apple silicon), pass the `--gpu` flag:

```shell
python3 client.py --gpu
```
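
Under the hood, the flag simply controls MLX's default device; these are the relevant lines from `client.py` (without `--gpu`, computation is pinned to the CPU, otherwise MLX uses the Metal GPU):

```python
if not args.gpu:
    mx.set_default_device(mx.cpu)
```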

Note that you can start many more clients if you want, but each will have to run in its own terminal.
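
Alternatively, the bundled `run.sh` script starts the server and two clients in the background for you:

```shell
bash run.sh
```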

You will see the federated training start, with MLX performing the local computation. Have a look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-mlx) for a detailed explanation.
119 changes: 119 additions & 0 deletions examples/quickstart-mlx/client.py
@@ -0,0 +1,119 @@
import argparse

import flwr as fl
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

import mnist


class MLP(nn.Module):
"""A simple MLP."""

def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]

def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)


def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))


def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)


def batch_iterate(batch_size, X, y):
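    # Shuffle the indices once, then yield successive mini-batches of (X, y).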
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def __init__(
self, model, optim, loss_and_grad_fn, data, num_epochs, batch_size
) -> None:
self.model = model
self.optimizer = optim
self.loss_and_grad_fn = loss_and_grad_fn
self.train_images, self.train_labels, self.test_images, self.test_labels = data
self.num_epochs = num_epochs
self.batch_size = batch_size

def get_parameters(self, config):
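        # Flower exchanges parameters as a flat list of NumPy arrays; flatten
        # MLX's nested parameter dict (one weight/bias pair per linear layer).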
layers = self.model.parameters()["layers"]
return [np.array(val) for layer in layers for _, val in layer.items()]

def set_parameters(self, parameters):
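        # Rebuild the nested MLX structure from the flat list: parameters
        # arrive as consecutive (weight, bias) pairs, one pair per layer.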
new_params = {}
new_params["layers"] = [
{"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
for i in range(0, len(parameters), 2)
]
self.model.update(new_params)

def fit(self, parameters, config):
self.set_parameters(parameters)
for _ in range(self.num_epochs):
for X, y in batch_iterate(
self.batch_size, self.train_images, self.train_labels
):
loss, grads = self.loss_and_grad_fn(self.model, X, y)
self.optimizer.update(self.model, grads)
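                # MLX evaluates lazily; materialize the updated parameters
                # and the optimizer state before the next step.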
mx.eval(self.model.parameters(), self.optimizer.state)
return self.get_parameters(config={}), len(self.train_images), {}

def evaluate(self, parameters, config):
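        # Flower expects evaluate() to return (loss, num_examples, metrics).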
self.set_parameters(parameters)
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
loss = loss_fn(self.model, self.test_images, self.test_labels)
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}


if __name__ == "__main__":
Review comment (Member): Let's put the following code in a main function

    parser = argparse.ArgumentParser(description="Train a simple MLP on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)

num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 1
learning_rate = 1e-1

train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=learning_rate)

# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(
model,
optimizer,
loss_and_grad_fn,
(train_images, train_labels, test_images, test_labels),
num_epochs,
batch_size,
),
)
69 changes: 69 additions & 0 deletions examples/quickstart-mlx/mnist.py
@@ -0,0 +1,69 @@
# Copyright © 2023 Apple Inc.

import gzip
import os
import pickle
from urllib import request

import numpy as np


def mnist(save_dir="/tmp"):
"""
Load the MNIST dataset in 4 tensors: train images, train labels,
test images, and test labels.

Checks `save_dir` for already downloaded data otherwise downloads.

Download code modified from:
https://github.com/hsjeong5/MNIST-for-Numpy
"""

def download_and_save(save_file):
base_url = "http://yann.lecun.com/exdb/mnist/"
filename = [
["training_images", "train-images-idx3-ubyte.gz"],
["test_images", "t10k-images-idx3-ubyte.gz"],
["training_labels", "train-labels-idx1-ubyte.gz"],
["test_labels", "t10k-labels-idx1-ubyte.gz"],
]

mnist = {}
for name in filename:
            out_file = os.path.join(save_dir, name[1])
request.urlretrieve(base_url + name[1], out_file)
for name in filename[:2]:
            out_file = os.path.join(save_dir, name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
-1, 28 * 28
)
for name in filename[-2:]:
            out_file = os.path.join(save_dir, name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
with open(save_file, "wb") as f:
pickle.dump(mnist, f)

save_file = os.path.join(save_dir, "mnist.pkl")
if not os.path.exists(save_file):
download_and_save(save_file)
with open(save_file, "rb") as f:
mnist = pickle.load(f)

preproc = lambda x: x.astype(np.float32) / 255.0
mnist["training_images"] = preproc(mnist["training_images"])
mnist["test_images"] = preproc(mnist["test_images"])
return (
mnist["training_images"],
mnist["training_labels"].astype(np.uint32),
mnist["test_images"],
mnist["test_labels"].astype(np.uint32),
)


if __name__ == "__main__":
train_x, train_y, test_x, test_y = mnist()
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
assert train_y.shape == (60000,), "Wrong training set size"
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
assert test_y.shape == (10000,), "Wrong test set size"
15 changes: 15 additions & 0 deletions examples/quickstart-mlx/pyproject.toml
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart-mlx"
version = "0.1.0"
description = "MLX Federated Learning Quickstart with Flower"
authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
mlx = "==0.0.3"
numpy = "==1.24.4"
3 changes: 3 additions & 0 deletions examples/quickstart-mlx/requirements.txt
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
mlx==0.0.3
numpy==1.24.4
17 changes: 17 additions & 0 deletions examples/quickstart-mlx/run.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in $(seq 0 1); do
echo "Starting client $i"
python client.py &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
25 changes: 25 additions & 0 deletions examples/quickstart-mlx/server.py
@@ -0,0 +1,25 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy: FedAvg, using the weighted average above to aggregate the
# "accuracy" metric reported by the clients
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)