From ee47a077a7512183ccd5df72227da095a2227c53 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 2 May 2024 09:37:22 +0200 Subject: [PATCH] Switch private and public key order in State --- .../server/superlink/state/in_memory_state.py | 8 ++++---- .../server/superlink/state/sqlite_state.py | 18 ++++++++--------- src/py/flwr/server/superlink/state/state.py | 6 +++--- .../flwr/server/superlink/state/state_test.py | 20 +++++++++---------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ebccac3509f..3c26e614bfe 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -254,16 +254,16 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes + def store_server_private_public_key( + self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" + """Store `server_private_key` and `server_public_key` in state.""" with self.lock: if self.server_private_key is None and self.server_public_key is None: self.server_private_key = private_key self.server_public_key = public_key else: - raise RuntimeError("Server public and private key already set") + raise RuntimeError("Server private and public key already set") def get_server_private_key(self) -> Optional[bytes]: """Retrieve `server_private_key` in urlsafe bytes.""" diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 39ed9263790..30dc03bdb95 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -42,8 +42,8 @@ SQL_CREATE_TABLE_CREDENTIAL = """ CREATE TABLE IF NOT EXISTS credential( - public_key BLOB PRIMARY KEY, - private_key BLOB + private_key BLOB PRIMARY KEY, + public_key BLOB ); """ @@ -589,20 +589,20 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes + def store_server_private_public_key( + self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" + """Store `server_private_key` and `server_public_key` in state.""" query = "SELECT COUNT(*) FROM credential" count = self.query(query)[0]["COUNT(*)"] if count < 1: query = ( - "INSERT OR REPLACE INTO credential (public_key, private_key) " - "VALUES (:public_key, :private_key)" + "INSERT OR REPLACE INTO credential (private_key, public_key) " + "VALUES (:private_key, :public_key)" ) - self.query(query, {"public_key": public_key, "private_key": private_key}) + self.query(query, {"private_key": private_key, "public_key": public_key}) else: - raise RuntimeError("Server public and private key already set") + raise RuntimeError("Server private and public key already set") def get_server_private_key(self) -> Optional[bytes]: """Retrieve `server_private_key` in urlsafe bytes.""" diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 7992aa2345a..caaaeeabafa 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -172,10 +172,10 @@ def get_run(self, run_id: int) -> Tuple[int, str, str]: """ @abc.abstractmethod - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes + def store_server_private_public_key( + self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" + """Store `server_private_key` and `server_public_key` in state.""" @abc.abstractmethod def get_server_private_key(self) -> Optional[bytes]: diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 0aeb7b064ad..c36bea506fc 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -414,8 +414,8 @@ def test_num_task_res(self) -> None: # Assert assert num == 2 - def test_server_public_private_key(self) -> None: - """Test get server public and private key after inserting.""" + def test_server_private_public_key(self) -> None: + """Test get server private and public key after inserting.""" # Prepare state: State = self.state_factory() private_key, public_key = generate_key_pairs() @@ -423,7 +423,7 @@ def test_server_public_private_key(self) -> None: public_key_bytes = public_key_to_bytes(public_key) # Execute - state.store_server_public_private_key(public_key_bytes, private_key_bytes) + state.store_server_private_public_key(private_key_bytes, public_key_bytes) server_private_key = state.get_server_private_key() server_public_key = state.get_server_public_key() @@ -431,8 +431,8 @@ def test_server_public_private_key(self) -> None: assert server_private_key == private_key_bytes assert server_public_key == public_key_bytes - def test_server_public_private_key_none(self) -> None: - """Test get server public and private key without inserting.""" + def test_server_private_public_key_none(self) -> None: + """Test get server private and public key without inserting.""" # Prepare state: State = self.state_factory() @@ -444,8 +444,8 @@ def test_server_public_private_key_none(self) -> None: assert server_private_key is None assert server_public_key is None - def test_store_server_public_private_key_twice(self) -> None: - """Test inserting public and private key twice.""" + def test_store_server_private_public_key_twice(self) -> None: + """Test inserting private and public key twice.""" # Prepare state: State = self.state_factory() private_key, public_key = generate_key_pairs() @@ -456,12 +456,12 @@ def test_store_server_public_private_key_twice(self) -> None: new_public_key_bytes = public_key_to_bytes(new_public_key) # Execute - state.store_server_public_private_key(public_key_bytes, private_key_bytes) + state.store_server_private_public_key(private_key_bytes, public_key_bytes) # Assert with self.assertRaises(RuntimeError): - state.store_server_public_private_key( - new_public_key_bytes, new_private_key_bytes + state.store_server_private_public_key( + new_private_key_bytes, new_public_key_bytes ) def test_client_public_keys(self) -> None: