Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new fields to Metadata class #2961

Merged
merged 31 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
491b199
add new fields to metadata class
panh99 Feb 15, 2024
3a7c23f
Merge branch 'main' into expand-metadata
panh99 Feb 15, 2024
9558fce
add message validation and unittests
panh99 Feb 15, 2024
c94b1a2
Merge branch 'main' into expand-metadata
panh99 Feb 15, 2024
a68282f
Merge branch 'main' into expand-metadata
panh99 Feb 16, 2024
6dc4eb4
add doc string
panh99 Feb 16, 2024
aa23eb5
Merge remote-tracking branch 'refs/remotes/origin/expand-metadata' in…
panh99 Feb 16, 2024
a01aedf
rename task type to message type
panh99 Feb 16, 2024
1452b9d
Merge branch 'main' into expand-metadata
panh99 Feb 16, 2024
2fed770
Remove Enum type
panh99 Feb 16, 2024
9a46ce9
simplify MessageType
panh99 Feb 16, 2024
fa52b93
add setters
panh99 Feb 16, 2024
f01a878
fix a bug
panh99 Feb 16, 2024
d0619bd
update type hint
panh99 Feb 16, 2024
4c9be90
add a comment
panh99 Feb 16, 2024
05c94e9
Merge branch 'main' into expand-metadata
panh99 Feb 16, 2024
fa73481
use MessageType
panh99 Feb 16, 2024
4fd07b6
Merge remote-tracking branch 'refs/remotes/origin/expand-metadata' in…
panh99 Feb 16, 2024
6b95c4c
rm MessageType
panh99 Feb 17, 2024
f3d0341
change Message.content to property
panh99 Feb 17, 2024
cad3721
Merge branch 'main' into expand-metadata
panh99 Feb 17, 2024
0e90485
disgard invalid out messages
panh99 Feb 18, 2024
0f50c62
add check in receive()
panh99 Feb 18, 2024
f17b0e2
rm state in connection.py
panh99 Feb 18, 2024
45affcb
Revert "rm state in connection.py"
panh99 Feb 18, 2024
df976ce
move message validation to send()
panh99 Feb 18, 2024
e1a0824
Merge branch 'main' into expand-metadata
panh99 Feb 18, 2024
e49d40d
adapt to main
panh99 Feb 18, 2024
f6df048
Merge branch 'main' into expand-metadata
panh99 Feb 18, 2024
5ef8e9f
rm repetitive imports
panh99 Feb 18, 2024
c0620cb
Merge branch 'main' into expand-metadata
danieljanes Feb 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def receive() -> Message:
metadata=Metadata(
run_id=0,
message_id=str(uuid.uuid4()),
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
node_id=0,
message_type=message_type,
),
content=recordset,
Expand Down Expand Up @@ -205,7 +207,7 @@ def send(message: Message) -> None:
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
else:
raise ValueError(f"Invalid task type: {message_type}")
raise ValueError(f"Invalid message type: {message_type}")

# Send ClientMessage proto
return queue.put(msg_proto, block=False)
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
metadata=Metadata(
run_id=0,
message_id="",
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
node_id=0,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -59,8 +61,10 @@
metadata=Metadata(
run_id=0,
message_id="",
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
node_id=0,
ttl="",
message_type="reconnect",
),
Expand Down
56 changes: 28 additions & 28 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,17 @@


from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast

from flwr.client.message_handler.task_handler import (
configure_task_res,
get_task_ins,
validate_task_ins,
validate_task_res,
)
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Message
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.grpc import create_channel
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.message import Message, Metadata
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand All @@ -41,7 +39,7 @@
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

KEY_NODE = "node"
KEY_TASK_INS = "current_task_ins"
KEY_METADATA = "in_message_metadata"


def on_channel_state_change(channel_connectivity: str) -> None:
Expand Down Expand Up @@ -102,8 +100,8 @@ def grpc_request_response(
channel.subscribe(on_channel_state_change)
stub = FleetStub(channel)

# Necessary state to link TaskRes to TaskIns
state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None}
# Necessary state to validate messages to be sent
state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}

# Enable create_node and delete_node to store node
node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
Expand Down Expand Up @@ -149,45 +147,47 @@ def receive() -> Optional[Message]:
task_ins: Optional[TaskIns] = get_task_ins(response)

# Discard the current TaskIns if not valid
if task_ins is not None and not validate_task_ins(task_ins):
if task_ins is not None and not (
task_ins.task.consumer.node_id == node.node_id
and validate_task_ins(task_ins)
):
task_ins = None

# Remember `task_ins` until `task_res` is available
state[KEY_TASK_INS] = task_ins
# Construct the Message
in_message = message_from_taskins(task_ins) if task_ins else None

# Remember `metadata` of the in message
state[KEY_METADATA] = copy(in_message.metadata) if in_message else None

# Return the message if available
return message_from_taskins(task_ins) if task_ins is not None else None
return in_message

def send(message: Message) -> None:
"""Send task result back to server."""
# Get Node
if node_store[KEY_NODE] is None:
log(ERROR, "Node instance missing")
return
node: Node = cast(Node, node_store[KEY_NODE])

# Get incoming TaskIns
if state[KEY_TASK_INS] is None:
log(ERROR, "No current TaskIns")
# Get incoming message
in_metadata = state[KEY_METADATA]
if in_metadata is None:
log(ERROR, "No current message")
return

# Validate out message
if not validate_out_message(message, in_metadata):
log(ERROR, "Invalid out message")
return
task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS])

# Construct TaskRes
task_res = message_to_taskres(message)

# Check if fields to be set are not initialized
if not validate_task_res(task_res):
state[KEY_TASK_INS] = None
log(ERROR, "TaskRes has been initialized accidentally")

# Configure TaskRes
task_res = configure_task_res(task_res, task_ins, node)
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

# Serialize ProtoBuf to bytes
request = PushTaskResRequest(task_res_list=[task_res])
_ = stub.PushTaskRes(request)

state[KEY_TASK_INS] = None
state[KEY_METADATA] = None

try:
# Yield methods
Expand Down
46 changes: 21 additions & 25 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.set_configs("config", ConfigsRecord({"reason": reason}))
out_message = Message(
metadata=Metadata(
run_id=0,
message_id="",
group_id="",
node_id=0,
ttl="",
message_type="reconnect",
),
content=recordset,
)
out_message = message.create_reply(recordset, ttl="")
# Return TaskRes and sleep duration
return out_message, sleep_duration

Expand All @@ -107,7 +97,7 @@ def handle_legacy_message_from_msgtype(
client_fn: ClientFn, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
client = client_fn(str(message.metadata.node_id))
client = client_fn(str(message.metadata.dst_node_id))

client.set_context(context)

Expand Down Expand Up @@ -144,21 +134,10 @@ def handle_legacy_message_from_msgtype(
)
out_recordset = evaluateres_to_recordset(evaluate_res)
else:
raise ValueError(f"Invalid task type: {message_type}")
raise ValueError(f"Invalid message type: {message_type}")

# Return Message
out_message = Message(
metadata=Metadata(
run_id=0,
message_id="",
group_id="",
node_id=0,
ttl="",
message_type=message_type,
),
content=out_recordset,
)
return out_message
return message.create_reply(out_recordset, ttl="")


def _reconnect(
Expand All @@ -173,3 +152,20 @@ def _reconnect(
# Build DisconnectRes message
disconnect_res = ClientMessage.DisconnectRes(reason=reason)
return ClientMessage(disconnect_res=disconnect_res), sleep_duration


def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> bool:
"""Validate the out message."""
out_meta = out_message.metadata
in_meta = in_message_metadata
if ( # pylint: disable-next=too-many-boolean-expressions
out_meta.run_id == in_meta.run_id
and out_meta.message_id == "" # This will be generated by the server
and out_meta.src_node_id == in_meta.dst_node_id
and out_meta.dst_node_id == in_meta.src_node_id
and out_meta.reply_to_message == in_meta.message_id
and out_meta.group_id == in_meta.group_id
and out_meta.message_type == in_meta.message_type
):
return True
return False
117 changes: 106 additions & 11 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
"""Client-side message handler tests."""


import unittest
import uuid
from copy import copy
from typing import List

from flwr.client import Client
from flwr.client.typing import ClientFn
Expand All @@ -40,7 +43,7 @@
from flwr.common import typing
from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES

from .message_handler import handle_legacy_message_from_msgtype
from .message_handler import handle_legacy_message_from_msgtype, validate_out_message


class ClientWithoutProps(Client):
Expand Down Expand Up @@ -122,10 +125,12 @@ def test_client_without_get_properties() -> None:
recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({}))
message = Message(
metadata=Metadata(
run_id=0,
run_id=123,
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
group_id="some group ID",
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -148,10 +153,22 @@ def test_client_without_get_properties() -> None:
properties={},
)
expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res)
expected_msg = Message(message.metadata, expected_rs)
expected_msg = Message(
metadata=Metadata(
run_id=123,
message_id="",
group_id="some group ID",
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=expected_rs,
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type
assert actual_msg.metadata == expected_msg.metadata


def test_client_with_get_properties() -> None:
Expand All @@ -161,10 +178,12 @@ def test_client_with_get_properties() -> None:
recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({}))
message = Message(
metadata=Metadata(
run_id=0,
run_id=123,
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
group_id="some group ID",
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -187,7 +206,83 @@ def test_client_with_get_properties() -> None:
properties={"str_prop": "val", "int_prop": 1},
)
expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res)
expected_msg = Message(message.metadata, expected_rs)
expected_msg = Message(
metadata=Metadata(
run_id=123,
message_id="",
group_id="some group ID",
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=expected_rs,
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type
assert actual_msg.metadata == expected_msg.metadata


class TestMessageValidation(unittest.TestCase):
"""Test message validation."""

def setUp(self) -> None:
"""Set up the message validation."""
# Common setup for tests
self.in_metadata = Metadata(
run_id=123,
message_id="qwerty",
src_node_id=10,
dst_node_id=20,
reply_to_message="",
group_id="group1",
ttl="60",
message_type="mock",
)
self.valid_out_metadata = Metadata(
run_id=123,
message_id="",
src_node_id=20,
dst_node_id=10,
reply_to_message="qwerty",
group_id="group1",
ttl="60",
message_type="mock",
)
self.common_content = RecordSet()

def test_valid_message(self) -> None:
"""Test a valid message."""
# Prepare
valid_message = Message(metadata=self.valid_out_metadata, content=RecordSet())

# Assert
self.assertTrue(validate_out_message(valid_message, self.in_metadata))

def test_invalid_message_run_id(self) -> None:
"""Test invalid messages."""
# Prepare
msg = Message(metadata=self.valid_out_metadata, content=RecordSet())

# Execute
invalid_metadata_list: List[Metadata] = []
attrs = list(vars(self.valid_out_metadata).keys())
for attr in attrs:
if attr == "_ttl": # Skip configurable ttl
continue
# Make an invalid metadata
invalid_metadata = copy(self.valid_out_metadata)
value = getattr(invalid_metadata, attr)
if isinstance(value, int):
value = 999
elif isinstance(value, str):
value = "999"
setattr(invalid_metadata, attr, value)
# Add to list
invalid_metadata_list.append(invalid_metadata)

# Assert
for invalid_metadata in invalid_metadata_list:
msg._metadata = invalid_metadata # pylint: disable=protected-access
self.assertFalse(validate_out_message(msg, self.in_metadata))
Loading
Loading