Skip to content

Commit

Permalink
Create StateFactory (#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Feb 24, 2023
1 parent 174a5a9 commit 09510f1
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 38 deletions.
43 changes: 28 additions & 15 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@
)
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.state import InMemoryState, State
from flwr.server.state import StateFactory
from flwr.server.strategy import FedAvg, Strategy

ADDRESS_DRIVER_API = "[::]:9091"
ADDRESS_FLEET_API_GRPC = "[::]:9092"

DATABASE = ":flwr-in-memory-state:"


@dataclass
class ServerConfig:
Expand Down Expand Up @@ -217,13 +219,13 @@ def run_driver_api() -> None:
event(EventType.RUN_DRIVER_API_ENTER)
args = _parse_args_driver()

# Init state
state = InMemoryState()
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Start server
grpc_server: grpc.Server = _run_driver_api_grpc(
address=args.driver_api_address,
state=state,
state_factory=state_factory,
)

# Graceful shutdown
Expand All @@ -243,13 +245,13 @@ def run_fleet_api() -> None:
event(EventType.RUN_FLEET_API_ENTER)
args = _parse_args_fleet()

# Init state
state = InMemoryState()
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Start server
grpc_server: grpc.Server = _run_fleet_api_grpc_bidi(
address=args.fleet_api_address,
state=state,
state_factory=state_factory,
)

# Graceful shutdown
Expand All @@ -269,19 +271,19 @@ def run_server() -> None:
event(EventType.RUN_SERVER_ENTER)
args = _parse_args()

# Shared State
state = InMemoryState()
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Start Driver API
driver_server: grpc.Server = _run_driver_api_grpc(
address=args.driver_api_address,
state=state,
state_factory=state_factory,
)

# Start Fleet API
fleet_server: grpc.Server = _run_fleet_api_grpc_bidi(
address=args.fleet_api_address,
state=state,
state_factory=state_factory,
)

# Graceful shutdown
Expand Down Expand Up @@ -340,13 +342,13 @@ def graceful_exit_handler( # type: ignore

def _run_driver_api_grpc(
address: str,
state: State,
state_factory: StateFactory,
) -> grpc.Server:
"""Run Driver API (gRPC, request-response)."""

# Create Driver API gRPC server
driver_servicer: grpc.Server = DriverServicer(
state=state,
state_factory=state_factory,
)
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
driver_grpc_server = generic_create_grpc_server(
Expand All @@ -364,13 +366,13 @@ def _run_driver_api_grpc(

def _run_fleet_api_grpc_bidi(
address: str,
state: State,
state_factory: StateFactory,
) -> grpc.Server:
"""Run Fleet API (gRPC, bidirectional streaming)."""

# DriverClientManager
driver_client_manager = DriverClientManager(
state=state,
state_factory=state_factory,
)

# Create (legacy) Fleet API gRPC server
Expand All @@ -397,6 +399,7 @@ def _parse_args_driver() -> argparse.Namespace:
description="Start Flower server (Driver API)",
)

_add_args_common(parser=parser)
_add_arg_driver_api_address(parser=parser)

return parser.parse_args()
Expand All @@ -408,6 +411,7 @@ def _parse_args_fleet() -> argparse.Namespace:
description="Start Flower server (Fleet API)",
)

_add_args_common(parser=parser)
_add_arg_fleet_api_address(parser=parser)

return parser.parse_args()
Expand All @@ -419,12 +423,21 @@ def _parse_args() -> argparse.Namespace:
description="Start Flower server (Driver API and Fleet API)",
)

_add_args_common(parser=parser)
_add_arg_driver_api_address(parser=parser)
_add_arg_fleet_api_address(parser=parser)

return parser.parse_args()


def _add_args_common(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by every server argument parser.

    Adds the ``--database`` option to ``parser``; when the option is not
    given on the command line it falls back to the module-level ``DATABASE``
    value.
    """
    database_help = f"Flower server database. Default: {DATABASE}"
    parser.add_argument("--database", default=DATABASE, help=database_help)


def _add_arg_driver_api_address(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--driver-api-address",
Expand Down
26 changes: 14 additions & 12 deletions src/py/flwr/server/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,23 @@
PushTaskInsResponse,
)
from flwr.proto.task_pb2 import TaskRes
from flwr.server.state import State
from flwr.server.state import State, StateFactory
from flwr.server.utils.validator import validate_task_ins_or_res


class DriverServicer(driver_pb2_grpc.DriverServicer):
"""Driver API servicer."""

def __init__(
self,
state: State,
) -> None:
self.state = state
def __init__(self, state_factory: StateFactory) -> None:
self.state_factory = state_factory

def GetNodes(
self, request: GetNodesRequest, context: grpc.ServicerContext
) -> GetNodesResponse:
"""Get available nodes."""
log(INFO, "DriverServicer.GetNodes")
all_ids: Set[int] = self.state.get_nodes()
state: State = self.state_factory.state()
all_ids: Set[int] = state.get_nodes()
return GetNodesResponse(node_ids=list(all_ids))

def PushTaskIns(
Expand All @@ -65,10 +63,13 @@ def PushTaskIns(
validation_errors = validate_task_ins_or_res(task_ins)
_raise_if(bool(validation_errors), ", ".join(validation_errors))

# Init state
state: State = self.state_factory.state()

# Store each TaskIns
task_ids: List[Optional[UUID]] = []
for task_ins in request.task_ins_list:
task_id: Optional[UUID] = self.state.store_task_ins(task_ins=task_ins)
task_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
task_ids.append(task_id)

return PushTaskInsResponse(
Expand All @@ -84,6 +85,9 @@ def PullTaskRes(
# Convert each task_id str to UUID
task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}

# Init state
state: State = self.state_factory.state()

# Register callback
def on_rpc_done() -> None:
log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
Expand All @@ -94,14 +98,12 @@ def on_rpc_done() -> None:
return

# Delete delivered TaskIns and TaskRes
self.state.delete_tasks(task_ids=task_ids)
state.delete_tasks(task_ids=task_ids)

context.add_callback(on_rpc_done)

# Read from state
task_res_list: List[TaskRes] = self.state.get_task_res(
task_ids=task_ids, limit=None
)
task_res_list: List[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None)

context.set_code(grpc.StatusCode.OK)
return PullTaskResResponse(task_res_list=task_res_list)
Expand Down
14 changes: 8 additions & 6 deletions src/py/flwr/server/grpc_server/driver_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.state import State
from flwr.server.state import State, StateFactory

from .ins_scheduler import InsScheduler


class DriverClientManager(ClientManager):
"""Provides a pool of available clients."""

def __init__(self, state: State) -> None:
def __init__(self, state_factory: StateFactory) -> None:
self._cv = threading.Condition()
self.nodes: Dict[str, Tuple[int, InsScheduler]] = {}
self.state = state
self.state_factory = state_factory

def __len__(self) -> int:
"""Return the number of available clients.
Expand Down Expand Up @@ -76,12 +76,13 @@ def register(self, client: ClientProxy) -> bool:
client.node_id = random_node_id

# Register node_id with State
self.state.register_node(node_id=random_node_id)
state: State = self.state_factory.state()
state.register_node(node_id=random_node_id)

# Create and start the instruction scheduler
ins_scheduler = InsScheduler(
client_proxy=client,
state=self.state,
state_factory=self.state_factory,
)
ins_scheduler.start()

Expand All @@ -108,7 +109,8 @@ def unregister(self, client: ClientProxy) -> None:
ins_scheduler.stop()

# Unregister node_id with State
self.state.unregister_node(node_id=node_id)
state: State = self.state_factory.state()
state.unregister_node(node_id=node_id)

with self._cv:
self._cv.notify_all()
Expand Down
12 changes: 7 additions & 5 deletions src/py/flwr/server/grpc_server/ins_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_proxy import ClientProxy
from flwr.server.state import State
from flwr.server.state import State, StateFactory


class InsScheduler:
"""Schedule ClientProxy calls on a background thread."""

def __init__(self, client_proxy: ClientProxy, state: State):
def __init__(self, client_proxy: ClientProxy, state_factory: StateFactory):
self.client_proxy = client_proxy
self.state = state
self.state_factory = state_factory
self.worker_thread: Optional[threading.Thread] = None
self.shared_memory_state = {"stop": False}

Expand All @@ -45,7 +45,7 @@ def start(self) -> None:
args=(
self.client_proxy,
self.shared_memory_state,
self.state,
self.state_factory,
),
)
self.worker_thread.start()
Expand All @@ -64,10 +64,12 @@ def stop(self) -> None:
def _worker(
client_proxy: ClientProxy,
shared_memory_state: Dict[str, bool],
state: State,
state_factory: StateFactory,
) -> None:
"""Sequentially call ClientProxy methods to process outstanding tasks."""
log(DEBUG, "Worker for node %i started", client_proxy.node_id)

state: State = state_factory.state()
while not shared_memory_state["stop"]:
log(DEBUG, "Worker for node %i checking state", client_proxy.node_id)

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from .in_memory_state import InMemoryState as InMemoryState
from .sqlite_state import SqliteState as SqliteState
from .state import State as State
from .state_factory import StateFactory as StateFactory

__all__ = [
"InMemoryState",
"SqliteState",
"State",
"StateFactory",
]
49 changes: 49 additions & 0 deletions src/py/flwr/server/state/state_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory class that creates State instances."""


from logging import DEBUG
from typing import Optional

from flwr.common.logger import log

from .in_memory_state import InMemoryState
from .sqlite_state import SqliteState
from .state import State


class StateFactory:
    """Hand out State objects for the configured database.

    The ``database`` string selects the backend: the special value
    ``":flwr-in-memory-state:"`` selects ``InMemoryState``; any other value
    is passed to ``SqliteState``.
    """

    def __init__(self, database: str) -> None:
        # Backend selector, stored as given by the caller.
        self.database = database
        # Cached in-memory state; only populated for the in-memory backend,
        # and only on the first call to `state()`.
        self.state_instance: Optional[State] = None

    def state(self) -> State:
        """Return a State for the configured database, creating it if needed.

        For the in-memory backend a single instance is created on first use
        and returned on every subsequent call. For any other database value a
        new ``SqliteState`` is constructed and initialized on each call.
        """
        if self.database != ":flwr-in-memory-state:":
            # SQLite backend: build and initialize a fresh state every time.
            sqlite_state = SqliteState(self.database)
            sqlite_state.initialize()
            log(DEBUG, "Using SqliteState")
            return sqlite_state

        # In-memory backend: create lazily, then keep reusing the same object.
        if self.state_instance is None:
            self.state_instance = InMemoryState()
        log(DEBUG, "Using InMemoryState")
        return self.state_instance

0 comments on commit 09510f1

Please sign in to comment.