Add Secaggplus example #2183

Merged: 60 commits, Aug 24, 2023
Commits
6c9ed36
Secaggplus (#2156)
panh99 Jul 21, 2023
80da399
Update src/py/flwr/client/secure_aggregation/__init__.py
panh99 Jul 30, 2023
78094c7
Merge branch 'main' into secaggplus
danieljanes Aug 1, 2023
822ff6d
rename and create consts
panh99 Aug 2, 2023
ee420b2
reset serde.py to main
panh99 Aug 2, 2023
6bed121
remove secagg example
panh99 Aug 2, 2023
62e2b54
init example
panh99 Aug 2, 2023
cfb57a6
add readme for sa example
panh99 Aug 2, 2023
c068161
Merge branch 'main' into secaggplus
panh99 Aug 2, 2023
b22ee04
Merge branch 'secaggplus' into secaggplus-example
panh99 Aug 3, 2023
89031cd
update name and re-format
panh99 Aug 3, 2023
b93bcc9
fix invalid name error
panh99 Aug 3, 2023
8a2c079
Add unittests for secaggplus handler
panh99 Aug 4, 2023
5224cc2
Merge branch 'main' into secaggplus
panh99 Aug 4, 2023
e25114a
Add content checks for SA messages
panh99 Aug 8, 2023
8b15810
update unit tests for SA handler
panh99 Aug 9, 2023
2c3ec2d
Merge branch 'main' into secaggplus
panh99 Aug 9, 2023
45ea7e2
Merge branch 'secaggplus' into secaggplus-example
panh99 Aug 10, 2023
bffd355
fix a bug causing exception in the collect masked input stage
panh99 Aug 10, 2023
c4e2afd
format secagg-mt-pytorch
panh99 Aug 10, 2023
7364f6f
Add underscores to constants
danieljanes Aug 14, 2023
9eeb6fd
remove and update comments
panh99 Aug 14, 2023
8ec786c
Update src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py
panh99 Aug 14, 2023
a26a1b9
Update src/py/flwr/client/secure_aggregation/secaggplus_handler.py
panh99 Aug 14, 2023
b394eb8
Update src/py/flwr/client/secure_aggregation/secaggplus_handler.py
panh99 Aug 14, 2023
004356d
Update src/py/flwr/client/secure_aggregation/secaggplus_handler.py
panh99 Aug 14, 2023
13e5848
update
panh99 Aug 15, 2023
960303e
Update comments
panh99 Aug 15, 2023
18933d2
separate consts for secagg+ to a new file
panh99 Aug 15, 2023
2a6d6f6
add a comment
panh99 Aug 15, 2023
51811ea
Merge branch 'main' into secaggplus
danieljanes Aug 16, 2023
4ebac72
fix a typo in comments
panh99 Aug 16, 2023
9442baa
Merge branch 'secaggplus' of https://github.com/adap/flower into seca…
panh99 Aug 16, 2023
dbd1a71
Merge branch 'secaggplus' into secaggplus-example
panh99 Aug 16, 2023
79bc48e
temp update
panh99 Aug 17, 2023
a4c7986
fix a bug in serde
panh99 Aug 17, 2023
5c38ad7
Merge branch 'fix-serde' into secaggplus-example
panh99 Aug 17, 2023
8798c4c
update secagg-mt-pytorch using secaggplus
panh99 Aug 17, 2023
c536c85
Merge branch 'main' into secaggplus-example
danieljanes Aug 22, 2023
820d6f4
update example
panh99 Aug 23, 2023
9350f42
update print
panh99 Aug 23, 2023
4d640b4
Merge branch 'main' into secaggplus-example
panh99 Aug 23, 2023
5b4a88c
formatting
panh99 Aug 23, 2023
28780c7
update readme
panh99 Aug 23, 2023
fff42c6
rename folder
panh99 Aug 23, 2023
accb071
rename
panh99 Aug 23, 2023
fdc82e7
Update examples/secaggplus-mt/client.py
danieljanes Aug 24, 2023
c714dd2
Update examples/secaggplus-mt/client.py
danieljanes Aug 24, 2023
15c6fa1
Update examples/secaggplus-mt/driver.py
danieljanes Aug 24, 2023
f705e87
Update examples/secaggplus-mt/workflows.py
danieljanes Aug 24, 2023
6ef7050
Update examples/secaggplus-mt/workflows.py
danieljanes Aug 24, 2023
383c04b
Update examples/secaggplus-mt/workflows.py
danieljanes Aug 24, 2023
4b72885
update example
panh99 Aug 24, 2023
e2b5b70
merge from upstream
panh99 Aug 24, 2023
058fba2
Merge branch 'main' into secaggplus-example
danieljanes Aug 24, 2023
31bb908
Add readme
panh99 Aug 24, 2023
42e58cc
Merge branch 'secaggplus-example' of https://github.com/adap/flower i…
panh99 Aug 24, 2023
ed2e957
Update examples/secaggplus-mt/client.py
danieljanes Aug 24, 2023
b30cb80
add files for poetry and pip
panh99 Aug 24, 2023
f54631d
Merge branch 'secaggplus-example' of https://github.com/adap/flower i…
panh99 Aug 24, 2023
36 changes: 36 additions & 0 deletions examples/secaggplus-mt/README.md
@@ -0,0 +1,36 @@
# Secure Aggregation with Driver API

This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced_pytorch)) to learn how to use Flower with PyTorch.

## Installing Dependencies

Project dependencies (such as `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to install dependencies and manage virtual environments another way if you prefer.

### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

### pip

Run the command below in your terminal to install the dependencies listed in `requirements.txt`.

```shell
pip install -r requirements.txt
```

If you don't see any errors you're good to go!
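
Beyond a bare `import flwr`, it can help to see which distribution actually got installed, since `requirements.txt` pulls in `flwr-nightly` rather than `flwr`. The following is a minimal sketch using only the standard library; the helper name `installed_version` is made up for this illustration and is not part of Flower.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Report both distribution names; typically only one of them is installed.
for name in ("flwr-nightly", "flwr"):
    print(name, installed_version(name))
```

A `None` for both names would suggest that the active environment is not the one the dependencies were installed into.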

## Run with Driver API

```bash
./run.sh
```
35 changes: 35 additions & 0 deletions examples/secaggplus-mt/client.py
@@ -0,0 +1,35 @@
import time

import numpy as np

import flwr as fl
from flwr.common import Status, FitIns, FitRes, Code
from flwr.common.parameter import ndarrays_to_parameters
from flwr.client.secure_aggregation import SecAggPlusHandler


# Define Flower client with the SecAgg+ protocol
class FlowerClient(fl.client.Client, SecAggPlusHandler):
    def fit(self, fit_ins: FitIns) -> FitRes:
        ret_vec = [np.ones(3)]
        ret = FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters(ret_vec),
            num_examples=1,
            metrics={},
        )
        # Force a significant delay for testing purposes
        if self._shared_state.sid == 0:
            print(f"Client {self._shared_state.sid} dropped for testing purposes.")
            time.sleep(4)
            return ret
        print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...")
        return ret


# Start Flower client
fl.client.start_client(
    server_address="0.0.0.0:9092",
    client=FlowerClient(),
    transport="grpc-rere",
)
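
Every client here uploads `np.ones(3)`, which makes the expected aggregate easy to check by eye. To illustrate why a server can recover a sum without seeing individual updates, here is a toy sketch of pairwise masking, the general idea behind SecAgg-style protocols. It is not the implementation in `SecAggPlusHandler`: the modulus `MOD`, the helper names, and the seed derivation (a stand-in for a real key agreement) are all assumptions made for this illustration, and dropout recovery is left out entirely.

```python
import numpy as np

MOD = 2**32  # hypothetical modulus for this toy example
M = np.uint64(MOD)  # same value as a NumPy scalar, keeps arithmetic in uint64


def pairwise_mask(i: int, j: int, size: int) -> np.ndarray:
    """Mask shared by clients i and j; the seed stands in for a key agreement."""
    lo, hi = min(i, j), max(i, j)
    rng = np.random.default_rng(1000 * lo + hi)
    return rng.integers(0, MOD, size=size, dtype=np.uint64)


def mask_update(client_id: int, update: np.ndarray, all_ids: list) -> np.ndarray:
    """Add the mask for every higher-numbered peer, subtract it for lower ones."""
    masked = update.astype(np.uint64) % M
    for other in all_ids:
        if other == client_id:
            continue
        mask = pairwise_mask(client_id, other, update.size)
        if client_id < other:
            masked = (masked + mask) % M
        else:
            masked = (masked + M - mask) % M
    return masked


ids = list(range(5))
masked = {i: mask_update(i, np.ones(3), ids) for i in ids}

# The server only ever sums masked vectors; each pair's masks cancel modulo MOD.
aggregate = np.zeros(3, dtype=np.uint64)
for vec in masked.values():
    aggregate = (aggregate + vec) % M
print(aggregate)  # [5 5 5]
```

Each individual `masked[i]` looks random, yet the sum equals the sum of the plain updates. When a client drops out (as the client with `sid == 0` does above), its masks no longer cancel on their own, which is the case the real protocol has to handle and this sketch does not.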
200 changes: 200 additions & 0 deletions examples/secaggplus-mt/driver.py
@@ -0,0 +1,200 @@
import random
import time
from typing import Dict, List, Tuple

import numpy as np
from workflows import get_workflow_factory

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.driver import Driver
from flwr.proto import driver_pb2, node_pb2, task_pb2
from flwr.server import History


# Convert instruction/result dict to/from list of TaskIns/TaskRes
def task_dict_to_task_ins_list(
    task_dict: Dict[int, task_pb2.Task]
) -> List[task_pb2.TaskIns]:
    def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
        _task.MergeFrom(_merge_task)
        return _task

    return [
        task_pb2.TaskIns(
            task_id="",  # Do not set, will be created and set by the DriverAPI
            group_id="",
            workload_id="",
            task=merge(
                task,
                task_pb2.Task(
                    producer=node_pb2.Node(
                        node_id=0,
                        anonymous=True,
                    ),
                    consumer=node_pb2.Node(
                        node_id=sampled_node_id,
                        # Must be False for this Secure Aggregation example
                        anonymous=False,
                    ),
                ),
            ),
        )
        for sampled_node_id, task in task_dict.items()
    ]


def task_res_list_to_task_dict(
    task_res_list: List[task_pb2.TaskRes],
) -> Dict[int, task_pb2.Task]:
    return {task_res.task.producer.node_id: task_res.task for task_res in task_res_list}
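
The two helpers above convert between a dict keyed by node ID and a flat list of tasks. The round trip can be sketched without protobuf using plain dataclasses. `Node`, `Task`, `to_task_list`, and `to_task_dict` below are simplified, hypothetical stand-ins for the `node_pb2`/`task_pb2` messages and the driver's helpers, not the real API.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Optional


@dataclass
class Node:
    node_id: int
    anonymous: bool


@dataclass
class Task:
    payload: str
    producer: Optional[Node] = None
    consumer: Optional[Node] = None


def to_task_list(task_dict: Dict[int, Task]) -> List[Task]:
    """Address each task from the driver (node 0, anonymous) to a named node."""
    return [
        replace(task, producer=Node(0, True), consumer=Node(node_id, False))
        for node_id, task in task_dict.items()
    ]


def to_task_dict(results: List[Task]) -> Dict[int, Task]:
    """Key each result by the node that produced it."""
    return {task.producer.node_id: task for task in results}


outgoing = to_task_list({7: Task("share keys"), 9: Task("share keys")})
# Pretend every node answers: the consumer of the request produces the reply.
replies = [
    Task("ok", producer=Node(t.consumer.node_id, False), consumer=Node(0, True))
    for t in outgoing
]
by_node = to_task_dict(replies)
print(sorted(by_node))  # [7, 9]

# Anonymous producers all carry node_id 0, so their results collide on one key.
collided = to_task_dict(
    [Task("a", producer=Node(0, True)), Task("b", producer=Node(0, True))]
)
print(len(collided))  # 1
```

The collision in the last lines is one plausible reading of why the consumer must not be anonymous in this example: results are keyed by producer node ID, and anonymous nodes would all map to the same key.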


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    examples = [num_examples for num_examples, _ in metrics]

    # Multiply accuracy of each client by number of examples used
    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
    train_accuracies = [
        num_examples * m["train_accuracy"] for num_examples, m in metrics
    ]
    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

    # Aggregate and return custom metric (weighted average)
    return {
        "train_loss": sum(train_losses) / sum(examples),
        "train_accuracy": sum(train_accuracies) / sum(examples),
        "val_loss": sum(val_losses) / sum(examples),
        "val_accuracy": sum(val_accuracies) / sum(examples),
    }
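
As a worked example of the weighting above: two clients reporting accuracies 0.8 and 0.6 on 10 and 30 examples give (10 · 0.8 + 30 · 0.6) / 40 = 0.65, not the unweighted mean 0.7. The sketch below computes this for a single key; `weighted_mean` and the simplified `Metrics` alias are illustrative names, not part of the driver.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]  # simplified stand-in for flwr.common.Metrics


def weighted_mean(metrics: List[Tuple[int, Metrics]], key: str) -> float:
    """Average one metric across clients, weighted by each client's example count."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m[key] for num_examples, m in metrics)
    return weighted_sum / total_examples


results = [(10, {"val_accuracy": 0.8}), (30, {"val_accuracy": 0.6})]
print(round(weighted_mean(results, "val_accuracy"), 4))  # 0.65
```

Note that `client.py` in this example returns `metrics={}`, so a function that indexes `m["train_loss"]` and the other keys would raise `KeyError` on those results; in the driver code shown here the aggregation function is defined but not called.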


# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
num_client_nodes_per_round = 5
sleep_time = 0.5
time_out = 3.9
num_rounds = 3
parameters = ndarrays_to_parameters([np.ones(3)])
wf_factory = get_workflow_factory()

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
# -------------------------------------------------------------------------- Driver SDK

history = History()
for server_round in range(num_rounds):
    print(f"Commencing server round {server_round + 1}")

    # List of sampled node IDs in this round
    sampled_node_ids: List[int] = []

    # Sample node ids
    if anonymous_client_nodes:
        # If we're working with anonymous clients, we don't know their identities, and
        # we don't know how many of them we have. We, therefore, have to assume that
        # enough anonymous client nodes are available or become available over time.
        #
        # To schedule a TaskIns for an anonymous client node, we set the node_id to 0
        # (and `anonymous` to True)
        # Here, we create an array with only zeros in it:
        sampled_node_ids = [0] * num_client_nodes_per_round
    else:
        # If our client nodes have identity (i.e., they are not anonymous), we can get
        # those IDs from the Driver API using `get_nodes`. If enough clients are
        # available via the Driver API, we can select a subset by taking a random
        # sample.
        #
        # The Driver API might not immediately return enough client node IDs, so we
        # loop and wait until enough client nodes are available.
        while True:
            # Get a list of node IDs from the server
            get_nodes_req = driver_pb2.GetNodesRequest()

            # ---------------------------------------------------------------------- Driver SDK
            get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
                req=get_nodes_req
            )
            # ---------------------------------------------------------------------- Driver SDK

            all_node_ids: List[int] = get_nodes_res.node_ids

            if len(all_node_ids) >= num_client_nodes_per_round:
                # Sample client nodes
                sampled_node_ids = random.sample(
                    all_node_ids, num_client_nodes_per_round
                )
                break

            time.sleep(3)

    # Log sampled node IDs
    print(f"Sampled {len(sampled_node_ids)} node IDs: {sampled_node_ids}")
    time.sleep(sleep_time)

    workflow = wf_factory(parameters, sampled_node_ids)
    node_messages = None

    while True:
        try:
            instructions: Dict[int, task_pb2.Task] = workflow.send(node_messages)
            next(workflow)
        except StopIteration:
            break
        # Schedule a task for all sampled nodes
        task_ins_list: List[task_pb2.TaskIns] = task_dict_to_task_ins_list(instructions)

        push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list)

        # ---------------------------------------------------------------------- Driver SDK
        push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins(
            req=push_task_ins_req
        )
        # ---------------------------------------------------------------------- Driver SDK

        time.sleep(sleep_time)

        # Wait for results, ignore empty task_ids
        start_time = time.time()
        task_ids: List[str] = [
            task_id for task_id in push_task_ins_res.task_ids if task_id != ""
        ]
        all_task_res: List[task_pb2.TaskRes] = []
        while True:
            if time.time() - start_time >= time_out:
                break
            pull_task_res_req = driver_pb2.PullTaskResRequest(
                node=node_pb2.Node(node_id=0, anonymous=True),
                task_ids=task_ids,
            )

            # ------------------------------------------------------------------ Driver SDK
            pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res(
                req=pull_task_res_req
            )
            # ------------------------------------------------------------------ Driver SDK

            task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list

            time.sleep(sleep_time)

            all_task_res += task_res_list
            if len(all_task_res) == len(task_ids):
                break

        # Collect correct results
        node_messages = task_res_list_to_task_dict(
            [res for res in all_task_res if res.task.HasField("sa")]
        )
    workflow.close()

    # Slow down the start of the next round
    time.sleep(sleep_time)

# -------------------------------------------------------------------------- Driver SDK
driver.disconnect()
# -------------------------------------------------------------------------- Driver SDK
print("Driver disconnected")
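
The round loop in `driver.py` drives the workflow as a generator: `workflow.send(node_messages)` returns the next batch of instructions, and the following `next(workflow)` advances it to the point where it waits for replies. Since `workflows.py` is not shown in this diff, the sketch below uses a made-up `toy_workflow` with two yields per stage, which is one generator shape that is compatible with that loop. The stage names and string payloads are assumptions for illustration only.

```python
from typing import Dict, Generator, List, Optional

Messages = Dict[int, str]


def toy_workflow(
    node_ids: List[int],
) -> Generator[Optional[Messages], Optional[Messages], None]:
    """Hypothetical two-stage workflow; not the one defined in workflows.py."""
    for stage in ("setup", "collect"):
        # First yield: hand the instructions for this stage to the driver.
        yield {node_id: f"{stage}-ins" for node_id in node_ids}
        # Second yield: pause until the driver sends the replies back.
        replies = yield
        print(f"{stage}: got replies from {sorted(replies)}")


def drive(workflow: Generator) -> List[Messages]:
    """Mirror the driver's send/next loop and record each batch of instructions."""
    sent: List[Messages] = []
    node_messages: Optional[Messages] = None
    while True:
        try:
            instructions = workflow.send(node_messages)
            next(workflow)
        except StopIteration:
            break
        sent.append(instructions)
        # Pretend every node answered its instruction.
        node_messages = {nid: ins.replace("ins", "res") for nid, ins in instructions.items()}
    workflow.close()
    return sent


batches = drive(toy_workflow([1, 2, 3]))
print(len(batches))  # 2
```

The first `send(None)` is what starts the generator, which is why `node_messages` is initialised to `None` before the loop. The generator finishing is signalled by `StopIteration`, which ends the round.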
13 changes: 13 additions & 0 deletions examples/secaggplus-mt/pyproject.toml
@@ -0,0 +1,13 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "secaggplus-mt"
version = "0.1.0"
description = "Secure Aggregation with Driver API"
authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr-nightly = { version = "^1.5.0.dev20230629", extras = ["simulation", "rest"] }
1 change: 1 addition & 0 deletions examples/secaggplus-mt/requirements.txt
@@ -0,0 +1 @@
flwr-nightly[simulation,rest]
35 changes: 35 additions & 0 deletions examples/secaggplus-mt/run.sh
@@ -0,0 +1,35 @@
#!/bin/bash
# Kill any currently running client.py processes
pkill -f 'python client.py'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'

# Start the flower server
echo "Starting flower server in background..."
flower-server --grpc-rere > /dev/null 2>&1 &
sleep 2

# Number of client processes to start
N=5 # Replace with your desired value

echo "Starting $N clients in background..."

# Start N client processes
for i in $(seq 1 $N)
do
    python client.py > /dev/null 2>&1 &
    # python client.py &
    sleep 0.1
done

echo "Starting driver..."
python driver.py

echo "Clearing background processes..."

# Kill any currently running client.py processes
pkill -f 'python client.py'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'
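
The script follows a common pattern: start background processes, run one foreground process, then clean up. Because `pkill -f` matches by command line, it can also hit unrelated processes with the same name. As a possible alternative, the sketch below tracks the processes it started and stops only those. `run_with_background` is a hypothetical helper, not part of this PR, and the demo uses short Python one-liners in place of `flower-server`, `client.py`, and `driver.py`.

```python
import subprocess
import sys
from typing import List, Sequence


def run_with_background(background: Sequence[List[str]], foreground: List[str]) -> int:
    """Start background commands, run the foreground one, then stop the rest."""
    procs = [
        subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        for cmd in background
    ]
    try:
        # Blocks until the foreground command exits, like `python driver.py`.
        return subprocess.run(foreground, check=False).returncode
    finally:
        # Stop only the processes started here, even if the foreground failed.
        for proc in procs:
            proc.terminate()
        for proc in procs:
            proc.wait(timeout=5)


# Stand-ins for the server/clients (long-running) and the driver (short-lived).
sleeper = [sys.executable, "-c", "import time; time.sleep(30)"]
exit_code = run_with_background(
    [sleeper, sleeper], [sys.executable, "-c", "print('driver done')"]
)
print(exit_code)
```

The `try`/`finally` ensures the background processes are stopped even when the foreground command raises or exits with an error, which the shell script only achieves if it reaches its final `pkill` lines.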