Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server auth interceptor #2948

Merged
merged 102 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
d0d4ecd
Add authentication state and test
danielnugraha Feb 10, 2024
91a2f18
Fix isort
danielnugraha Feb 10, 2024
db16c10
Fix isort
danielnugraha Feb 10, 2024
28876bc
Run format.sh
danielnugraha Feb 10, 2024
42a7d38
Add init.py
danielnugraha Feb 10, 2024
8ec63c9
Fix line too long
danielnugraha Feb 10, 2024
8f04e25
Fix line too long
danielnugraha Feb 10, 2024
e8813fc
Fix line too long
danielnugraha Feb 10, 2024
d9f3fb0
Fix subclassing
danielnugraha Feb 10, 2024
caf6695
Fix subclassing
danielnugraha Feb 10, 2024
fa217ae
Fix subclassing
danielnugraha Feb 10, 2024
6edddd6
Fix subclassing
danielnugraha Feb 10, 2024
8bb15a5
Fix subclassing
danielnugraha Feb 10, 2024
c5bac4f
fixes
jafermarq Feb 11, 2024
c856b7c
Fix state tests
danielnugraha Feb 11, 2024
4758507
Fix too broad exception
danielnugraha Feb 11, 2024
e666da5
Add sqlite auth state test
danielnugraha Feb 11, 2024
151a619
Merge remote-tracking branch 'origin' into add-auth-state
danielnugraha Feb 11, 2024
5c49a55
Add server interceptor
danielnugraha Feb 12, 2024
e443bf9
Merge remote-tracking branch 'origin' into add-server-auth-interceptor
danielnugraha Feb 14, 2024
986961e
Move state to superlink
danielnugraha Feb 14, 2024
dddbbc9
Move state to superlink
danielnugraha Feb 14, 2024
fbbcb2a
Fix server interceptor
danielnugraha Feb 15, 2024
77e5c3c
Fix authentication state
danielnugraha Feb 15, 2024
21e590a
Add symmetric encryption test
danielnugraha Feb 15, 2024
6823c83
Add symmetric encryption test
danielnugraha Feb 15, 2024
9df829c
Format code
danielnugraha Feb 15, 2024
d250945
Make tests pass
danielnugraha Feb 15, 2024
de0f041
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 15, 2024
2f61623
Revert commit to only include auth state
danielnugraha Feb 15, 2024
2f7aa48
Remove logging messages
danielnugraha Feb 16, 2024
d96879c
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Feb 16, 2024
f9a21b4
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 21, 2024
6ee8a61
Update server interceptor
danielnugraha Feb 22, 2024
041482e
Merge from add-auth-cli
danielnugraha Feb 22, 2024
781796e
Docstring changes
danielnugraha Feb 28, 2024
651f665
Merge branch 'main' into add-auth-state
danielnugraha Feb 28, 2024
013582f
Merge auth state
danielnugraha Feb 28, 2024
700f6dd
Merge main
danielnugraha Feb 28, 2024
ab5317f
Fix merge conflict interceptors gone
danielnugraha Feb 28, 2024
5075ab7
Fix too many instances
danielnugraha Feb 28, 2024
7f49d81
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 28, 2024
d6238ee
Fix imports merge conflict
danielnugraha Feb 29, 2024
578fd96
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 29, 2024
2de5fd5
Add docstring to interceptor
danielnugraha Feb 29, 2024
9b62c4f
Format
danielnugraha Mar 1, 2024
e2ad1ef
Implement feedback
danielnugraha Mar 1, 2024
811d8e8
Merge remote-tracking branch 'origin' into add-server-auth-interceptor
danielnugraha Mar 1, 2024
7a5b6f0
Fix merge conflicts
danielnugraha Apr 3, 2024
28afce5
Format
danielnugraha Apr 4, 2024
78a6697
Fix merge conflicts
danielnugraha Apr 4, 2024
7c098d2
Fix error
danielnugraha Apr 4, 2024
100eadb
Fix error
danielnugraha Apr 4, 2024
5c9f6c3
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 4, 2024
974d2b2
Merge branch 'main' into add-auth-state
danielnugraha Apr 4, 2024
f10333d
Merge branch 'main' into add-auth-state
danielnugraha Apr 15, 2024
74f7036
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 15, 2024
d2b60f6
Merge remote-tracking branch 'origin' into add-auth-state
danielnugraha Apr 18, 2024
94ee2fe
Add lock to write operations
danielnugraha Apr 18, 2024
08c07b5
Merge branch 'add-auth-state' into add-server-auth-interceptor
danielnugraha Apr 18, 2024
0fe2744
Fix docstring
danielnugraha Apr 21, 2024
cc682f3
Fix merge conflict
danielnugraha Apr 24, 2024
54b7afa
Format
danielnugraha Apr 24, 2024
4ab971b
Format
danielnugraha Apr 24, 2024
054bd04
Add more tests
danielnugraha Apr 24, 2024
aac42ab
Add failure tests
danielnugraha Apr 24, 2024
c5b3e46
Add failure tests
danielnugraha Apr 24, 2024
e76b937
Format
danielnugraha Apr 24, 2024
846373f
Fix docstring
danielnugraha Apr 24, 2024
a61892e
Format prepare, execute & assert
danielnugraha Apr 24, 2024
5212653
Merge from main
danielnugraha Apr 24, 2024
8b58c5e
Add get run
danielnugraha Apr 24, 2024
bd6163b
Dynamically generate ssh key
danielnugraha Apr 24, 2024
1dc292a
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 24, 2024
6a02ba9
Encode only once
danielnugraha Apr 24, 2024
e709e99
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Apr 24, 2024
d702924
Format
danielnugraha Apr 24, 2024
ac54694
Add get_run
danielnugraha Apr 24, 2024
5735549
Unindent function
danielnugraha Apr 24, 2024
1ca16ba
Format
danielnugraha Apr 24, 2024
5c2f805
Merge branch 'main' into add-server-auth-interceptor
danieljanes Apr 24, 2024
64ec5e8
Update src/py/flwr/server/app.py
danielnugraha Apr 24, 2024
9c03cd6
Update src/py/flwr/server/app.py
danielnugraha Apr 24, 2024
4779dc0
Implement review feedback
danielnugraha Apr 24, 2024
af03db1
Adapt error string
danielnugraha Apr 25, 2024
af1773b
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 26, 2024
24b7182
Update src/py/flwr/server/app.py
danielnugraha Apr 26, 2024
16e20d4
Change data to maybe_keys
danielnugraha Apr 26, 2024
79728df
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 26, 2024
a742d40
Implement feedback
danielnugraha Apr 26, 2024
461802a
Private key first then public
danielnugraha Apr 26, 2024
cc61c1c
Add context to message handler as comment
danielnugraha Apr 29, 2024
948d8a7
Remove state usage
danielnugraha Apr 29, 2024
9cde4a8
Fix server_test
danielnugraha Apr 29, 2024
bbab863
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 29, 2024
4d2f17c
Apply suggestions from code review
danieljanes Apr 29, 2024
ff5f3b8
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 29, 2024
b7357bc
Update src/py/flwr/server/superlink/fleet/grpc_rere/server_intercepto…
danieljanes Apr 29, 2024
499dbdc
Revert delete state
danielnugraha Apr 29, 2024
7e06ca1
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Apr 29, 2024
cd4451c
Rename message handler to method handler
danielnugraha Apr 29, 2024
1329328
Change Any to Response type
danielnugraha Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import base64
from typing import Tuple, cast

from cryptography.exceptions import InvalidSignature
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives import hashes, hmac, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

Expand Down Expand Up @@ -98,3 +99,21 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
# The input key must be url safe
fernet = Fernet(key)
return fernet.decrypt(ciphertext)


def compute_hmac(key: bytes, message: bytes) -> bytes:
"""Compute hmac of a message using key as hash."""
computed_hmac = hmac.HMAC(key, hashes.SHA256())
computed_hmac.update(message)
return computed_hmac.finalize()


def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
"""Verify hmac of a message using key as hash."""
computed_hmac = hmac.HMAC(key, hashes.SHA256())
computed_hmac.update(message)
try:
computed_hmac.verify(hmac_value)
return True
except InvalidSignature:
return False
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symmetric encryption tests."""


from .symmetric_encryption import (
compute_hmac,
generate_key_pairs,
generate_shared_key,
verify_hmac,
)


def test_generate_shared_key() -> None:
"""Test util function generate_shared_key."""
client_keys = generate_key_pairs()
server_keys = generate_key_pairs()

client_shared_secret = generate_shared_key(client_keys[0], server_keys[1])
server_shared_secret = generate_shared_key(server_keys[0], client_keys[1])

assert client_shared_secret == server_shared_secret


def test_hmac() -> None:
"""Test util function compute and verify hmac."""
client_keys = generate_key_pairs()
server_keys = generate_key_pairs()
client_shared_secret = generate_shared_key(client_keys[0], server_keys[1])
server_shared_secret = generate_shared_key(server_keys[0], client_keys[1])
message = b"Flower is the future of AI"

client_compute_hmac = compute_hmac(client_shared_secret, message)

assert verify_hmac(server_shared_secret, message, client_compute_hmac)
94 changes: 93 additions & 1 deletion src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

import argparse
import asyncio
import csv
import importlib.util
import sys
import threading
from logging import ERROR, INFO, WARN
from os.path import isfile
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Set, Tuple

import grpc
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import (
load_ssh_private_key,
load_ssh_public_key,
)

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
Expand All @@ -36,6 +42,9 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
public_key_to_bytes,
)
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
add_FleetServicer_to_server,
)
Expand All @@ -44,6 +53,7 @@
from .history import History
from .server import Server, init_defaults, run_fl
from .server_config import ServerConfig
from .server_interceptor import AuthenticateServerInterceptor
from .strategy import Strategy
from .superlink.driver.driver_grpc import run_driver_api_grpc
from .superlink.fleet.grpc_bidi.grpc_server import (
Expand Down Expand Up @@ -352,10 +362,29 @@ def run_superlink() -> None:
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
host, port, is_v6 = parsed_address
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

data = _try_setup_client_authentication(args, certificates)
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if data is not None:
(
client_public_keys,
server_public_key,
server_private_key,
) = data
interceptors = [
AuthenticateServerInterceptor(
state_factory,
client_public_keys,
server_private_key,
server_public_key,
)
]

fleet_server = _run_fleet_api_grpc_rere(
address=address,
state_factory=state_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
Expand Down Expand Up @@ -388,6 +417,59 @@ def run_superlink() -> None:
driver_server.wait_for_termination(timeout=1)


def _try_setup_client_authentication(
args: argparse.Namespace,
certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]:
if args.require_client_authentication:
if certificates is None:
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
sys.exit(
"Certificates are required to enable client authentication. "
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
"Please provide certificate paths with '--certificates' before "
"enabling '--require-client-authentication'."
)
client_keys_file_path = Path(args.require_client_authentication[0])
if not client_keys_file_path.exists():
sys.exit(
"Client public keys csv file are required for client authentication. "
"Please provide the csv file path containing known client public keys "
"to '--require-client-authentication'."
)
client_public_keys: Set[bytes] = set()
public_key = load_ssh_public_key(
Path(args.require_client_authentication[1]).read_bytes()
)
private_key = load_ssh_private_key(
Path(args.require_client_authentication[2]).read_bytes(),
None,
)
if not isinstance(public_key, ec.EllipticCurvePublicKey) or not isinstance(
private_key, ec.EllipticCurvePrivateKey
):
sys.exit(
"An eliptic curve public and private key pair is required for "
"client authentication. Please provide the file path containing "
"valid public and private key to '--require-client-authentication'."
)
server_public_key = public_key
server_private_key = private_key

with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
for row in reader:
for element in row:
public_key = load_ssh_public_key(element.encode())
if isinstance(public_key, ec.EllipticCurvePublicKey):
client_public_keys.add(public_key_to_bytes(public_key))
return (
client_public_keys,
server_public_key,
server_private_key,
)
else:
return None
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved


def _try_obtain_certificates(
args: argparse.Namespace,
) -> Optional[Tuple[bytes, bytes, bytes]]:
Expand Down Expand Up @@ -415,6 +497,7 @@ def _run_fleet_api_grpc_rere(
address: str,
state_factory: StateFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
) -> grpc.Server:
"""Run Fleet API (gRPC, request-response)."""
# Create Fleet API gRPC server
Expand All @@ -427,6 +510,7 @@ def _run_fleet_api_grpc_rere(
server_address=address,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
certificates=certificates,
interceptors=interceptors,
)

log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
Expand Down Expand Up @@ -606,6 +690,14 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
"Flower will just create a state in memory.",
default=DATABASE,
)
parser.add_argument(
"--require-client-authentication",
nargs=3,
metavar=("CLIENT_KEYS", "SERVER_PUBLIC_KEY", "SERVER_PRIVATE_KEY"),
type=str,
help="Paths to .csv file containing list of known client public keys for "
"authentication, server public key, and server private key, in that order.",
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
)


def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
Expand Down
Loading