Skip to content

Commit

Permalink
Switch private and public key order in State (#3386)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielnugraha authored May 2, 2024
1 parent bfcb4af commit 91ab0cf
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,16 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
log(ERROR, "Unexpected run creation failure.")
return 0

def store_server_public_private_key(
self, public_key: bytes, private_key: bytes
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
) -> None:
"""Store `server_public_key` and `server_private_key` in state."""
"""Store `server_private_key` and `server_public_key` in state."""
with self.lock:
if self.server_private_key is None and self.server_public_key is None:
self.server_private_key = private_key
self.server_public_key = public_key
else:
raise RuntimeError("Server public and private key already set")
raise RuntimeError("Server private and public key already set")

def get_server_private_key(self) -> Optional[bytes]:
"""Retrieve `server_private_key` in urlsafe bytes."""
Expand Down
18 changes: 9 additions & 9 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@

SQL_CREATE_TABLE_CREDENTIAL = """
CREATE TABLE IF NOT EXISTS credential(
public_key BLOB PRIMARY KEY,
private_key BLOB
private_key BLOB PRIMARY KEY,
public_key BLOB
);
"""

Expand Down Expand Up @@ -589,20 +589,20 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
log(ERROR, "Unexpected run creation failure.")
return 0

def store_server_public_private_key(
self, public_key: bytes, private_key: bytes
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
) -> None:
"""Store `server_public_key` and `server_private_key` in state."""
"""Store `server_private_key` and `server_public_key` in state."""
query = "SELECT COUNT(*) FROM credential"
count = self.query(query)[0]["COUNT(*)"]
if count < 1:
query = (
"INSERT OR REPLACE INTO credential (public_key, private_key) "
"VALUES (:public_key, :private_key)"
"INSERT OR REPLACE INTO credential (private_key, public_key) "
"VALUES (:private_key, :public_key)"
)
self.query(query, {"public_key": public_key, "private_key": private_key})
self.query(query, {"private_key": private_key, "public_key": public_key})
else:
raise RuntimeError("Server public and private key already set")
raise RuntimeError("Server private and public key already set")

def get_server_private_key(self) -> Optional[bytes]:
"""Retrieve `server_private_key` in urlsafe bytes."""
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ def get_run(self, run_id: int) -> Tuple[int, str, str]:
"""

@abc.abstractmethod
def store_server_public_private_key(
self, public_key: bytes, private_key: bytes
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
) -> None:
"""Store `server_public_key` and `server_private_key` in state."""
"""Store `server_private_key` and `server_public_key` in state."""

@abc.abstractmethod
def get_server_private_key(self) -> Optional[bytes]:
Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,25 +414,25 @@ def test_num_task_res(self) -> None:
# Assert
assert num == 2

def test_server_public_private_key(self) -> None:
"""Test get server public and private key after inserting."""
def test_server_private_public_key(self) -> None:
"""Test get server private and public key after inserting."""
# Prepare
state: State = self.state_factory()
private_key, public_key = generate_key_pairs()
private_key_bytes = private_key_to_bytes(private_key)
public_key_bytes = public_key_to_bytes(public_key)

# Execute
state.store_server_public_private_key(public_key_bytes, private_key_bytes)
state.store_server_private_public_key(private_key_bytes, public_key_bytes)
server_private_key = state.get_server_private_key()
server_public_key = state.get_server_public_key()

# Assert
assert server_private_key == private_key_bytes
assert server_public_key == public_key_bytes

def test_server_public_private_key_none(self) -> None:
"""Test get server public and private key without inserting."""
def test_server_private_public_key_none(self) -> None:
"""Test get server private and public key without inserting."""
# Prepare
state: State = self.state_factory()

Expand All @@ -444,8 +444,8 @@ def test_server_public_private_key_none(self) -> None:
assert server_private_key is None
assert server_public_key is None

def test_store_server_public_private_key_twice(self) -> None:
"""Test inserting public and private key twice."""
def test_store_server_private_public_key_twice(self) -> None:
"""Test inserting private and public key twice."""
# Prepare
state: State = self.state_factory()
private_key, public_key = generate_key_pairs()
Expand All @@ -456,12 +456,12 @@ def test_store_server_public_private_key_twice(self) -> None:
new_public_key_bytes = public_key_to_bytes(new_public_key)

# Execute
state.store_server_public_private_key(public_key_bytes, private_key_bytes)
state.store_server_private_public_key(private_key_bytes, public_key_bytes)

# Assert
with self.assertRaises(RuntimeError):
state.store_server_public_private_key(
new_public_key_bytes, new_private_key_bytes
state.store_server_private_public_key(
new_private_key_bytes, new_public_key_bytes
)

def test_client_public_keys(self) -> None:
Expand Down

0 comments on commit 91ab0cf

Please sign in to comment.