Skip to content

Commit

Permalink
Create new REST API for client nodes (#1594)
Browse files Browse the repository at this point in the history
Co-authored-by: Pedro Porto Buarque de Gusmão <pedropgusmao@gmail.com>
Co-authored-by: Taner Topal <taner@adap.com>
  • Loading branch information
3 people authored Feb 26, 2023
1 parent 09510f1 commit 1757d5a
Show file tree
Hide file tree
Showing 12 changed files with 690 additions and 35 deletions.
3 changes: 2 additions & 1 deletion examples/mt-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def evaluate(self, parameters, config):

# Start Flower client
fl.client.start_numpy_client(
server_address="[::]:9092",
server_address="0.0.0.0:9093",
client=FlowerClient(),
use_rest=True,
)
4 changes: 2 additions & 2 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from task import Net, get_parameters, set_parameters

# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="[::]:9091", certificates=None)
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
anonymous_client_nodes = True
num_client_nodes_per_round = 1
sleep_time = 1
num_rounds = 1
Expand Down
2 changes: 1 addition & 1 deletion examples/mt-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] }
torch = "^1.12.0"
torchvision = "^0.13.0"
tqdm = "^4.63.0"
4 changes: 2 additions & 2 deletions examples/mt-pytorch/run-grpc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)"

echo "Starting server"
flower-server &
flower-server --rest &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 2`; do
for i in `seq 0 0`; do
echo "Starting client $i"
python client.py &
done
Expand Down
20 changes: 19 additions & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .numpy_client import has_fit as numpyclient_has_fit
from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
from .numpy_client import has_get_properties as numpyclient_has_get_properties
from .rest_client.connection import http_request_response

EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
NumPyClient.fit did not return a tuple with 3 elements.
Expand Down Expand Up @@ -86,6 +87,7 @@ def start_client(
client: Client,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
use_rest: bool = False,
) -> None:
"""Start a Flower Client which connects to a gRPC server.
Expand All @@ -107,6 +109,10 @@ class `flwr.client.Client`.
The PEM-encoded root certificates as a byte string. If provided, a secure
connection using the certificates will be established to a
SSL-enabled Flower server.
use_rest: bool (default: False)
Defines whether or not the client is interacting with the server using the
experimental REST API. This feature is experimental, it might change
considerably in future versions of Flower.
Returns
-------
Expand All @@ -130,11 +136,14 @@ class `flwr.client.Client`.
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""

event(EventType.START_CLIENT_ENTER)

# Use either gRPC bidirectional streaming or REST request/response
connection = http_request_response if use_rest else grpc_connection
while True:
sleep_duration: int = 0
with grpc_connection(
with connection(
server_address,
max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
Expand All @@ -143,6 +152,9 @@ class `flwr.client.Client`.

while True:
server_message = receive()
if server_message is None:
time.sleep(3) # Wait for 3s before asking again
continue
client_message, sleep_duration, keep_going = handle(
client, server_message
)
Expand All @@ -169,6 +181,7 @@ def start_numpy_client(
client: NumPyClient,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
use_rest: bool = False,
) -> None:
"""Start a Flower NumPyClient which connects to a gRPC server.
Expand All @@ -191,6 +204,10 @@ def start_numpy_client(
The PEM-encoded root certificates a byte string. If provided, a secure
connection using the certificates will be established to a
SSL-enabled Flower server.
use_rest: bool (default: False)
Defines whether or not the client is interacting with the server using the
experimental REST API. This feature is experimental, it might be change
considerably in future versions of Flower.
Examples
--------
Expand All @@ -217,6 +234,7 @@ def start_numpy_client(
client=_wrap_numpy_client(client=client),
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
use_rest=use_rest,
)


Expand Down
46 changes: 46 additions & 0 deletions src/py/flwr/client/message_handler/task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2023 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Task handling."""


from typing import Optional, Tuple

from flwr.proto.fleet_pb2 import PullTaskInsResponse
from flwr.proto.task_pb2 import TaskIns
from flwr.proto.transport_pb2 import ServerMessage


def get_server_message(
pull_task_ins_response: PullTaskInsResponse,
) -> Optional[Tuple[TaskIns, ServerMessage]]:
"""Get the first ServerMessage, if available."""

# Extract a single ServerMessage from the response, if possible
if len(pull_task_ins_response.task_ins_list) == 0:
return None

# Only evaluate the first message
task_ins: TaskIns = pull_task_ins_response.task_ins_list[0]

# Discard the message if it is not in
# {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns}
if (
not task_ins.HasField("task")
or not task_ins.task.HasField("legacy_server_message")
or task_ins.task.legacy_server_message.WhichOneof("msg") == "reconnect_ins"
):
return None

return task_ins, task_ins.task.legacy_server_message
128 changes: 128 additions & 0 deletions src/py/flwr/client/message_handler/task_handler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2023 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for module task_handler."""


from flwr.client.message_handler.task_handler import get_server_message
from flwr.proto.fleet_pb2 import PullTaskInsResponse, Reconnect
from flwr.proto.task_pb2 import Task, TaskIns
from flwr.proto.transport_pb2 import ServerMessage


def test_get_server_message_empty() -> None:
"""Test get_server_message."""

# Prepare
res = PullTaskInsResponse(reconnect=None, task_ins_list=[])

# Execute
actual = get_server_message(res)

# Assert
assert actual is None


def test_get_server_message_reconnect() -> None:
"""Test get_server_message."""

# Prepare
res = PullTaskInsResponse(reconnect=Reconnect(reconnect=42), task_ins_list=[])

# Execute
actual = get_server_message(res)

# Assert
assert actual is None


def test_get_server_message_none_task() -> None:
"""Test get_server_message."""

# Prepare
res = PullTaskInsResponse(reconnect=None, task_ins_list=[TaskIns(task=None)])

# Execute
actual = get_server_message(res)

# Assert
assert actual is None


def test_get_server_message_none_legacy() -> None:
"""Test get_server_message."""

# Prepare
res = PullTaskInsResponse(
reconnect=None, task_ins_list=[TaskIns(task=Task(legacy_server_message=None))]
)

# Execute
actual = get_server_message(res)

# Assert
assert actual is None


def test_get_server_message_legacy_reconnect() -> None:
"""Test get_server_message."""

# Prepare
res = PullTaskInsResponse(
reconnect=None,
task_ins_list=[
TaskIns(
task=Task(
legacy_server_message=ServerMessage(
reconnect_ins=ServerMessage.ReconnectIns(seconds=3)
)
)
)
],
)

# Execute
actual = get_server_message(res)

# Assert
assert actual is None


def test_get_server_message_legacy_valid() -> None:
"""Test get_server_message."""

# Prepare
expected = TaskIns(
task=Task(
legacy_server_message=ServerMessage(
get_properties_ins=ServerMessage.GetPropertiesIns()
)
)
)
res = PullTaskInsResponse(
reconnect=None,
task_ins_list=[expected],
)

# Execute
actual = get_server_message(res)

# Assert
assert actual is not None
actual_task_ins, actual_server_message = actual
assert actual_task_ins == expected

# pylint: disable=no-member
assert actual_server_message == expected.task.legacy_server_message
# pylint: enable=no-member
15 changes: 15 additions & 0 deletions src/py/flwr/client/rest_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Client-side part of the REST transport layer."""
Loading

0 comments on commit 1757d5a

Please sign in to comment.