Skip to content

Commit

Permalink
Speed up decrypting frames (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 3, 2024
1 parent 0969e93 commit 7e7ece4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
20 changes: 18 additions & 2 deletions aioesphomeapi/_frame_helper/noise.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED

cdef bytes NOISE_HELLO
cdef object PACK_NONCE

cdef class EncryptCipher:

cdef object _nonce
cdef object _encrypt

cdef bytes encrypt(self, object frame)

cdef class DecryptCipher:

cdef object _nonce
cdef object _decrypt

cdef bytes decrypt(self, object frame)

cdef class APINoiseFrameHelper(APIFrameHelper):

Expand All @@ -20,8 +35,8 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
cdef unsigned int _state
cdef object _server_name
cdef object _proto
cdef object _decrypt
cdef object _encrypt
cdef EncryptCipher _encrypt_cipher
cdef DecryptCipher _decrypt_cipher

@cython.locals(
header=bytes,
Expand Down Expand Up @@ -59,6 +74,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
@cython.locals(
type_="unsigned int",
data=bytes,
data_header=bytes,
packet=tuple,
data_len=cython.uint,
frame=bytes,
Expand Down
70 changes: 53 additions & 17 deletions aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from functools import partial
import logging
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag
from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped]
from noise.backends.default.ciphers import ( # type: ignore[import-untyped]
ChaCha20Cipher,
CryptographyCipher,
)
from noise.connection import NoiseConnection # type: ignore[import-untyped]
from noise.state import CipherState # type: ignore[import-untyped]

from ..core import (
APIConnectionError,
Expand All @@ -30,6 +32,8 @@

PACK_NONCE = partial(Struct("<LQ").pack, 0)

_bytes = bytes


class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
"""ChaCha20 cipher that can be reused."""
Expand Down Expand Up @@ -68,6 +72,44 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
int_ = int


class EncryptCipher:
"""Wrapper around the ChaCha20Poly1305 cipher for encryption."""

__slots__ = ("_nonce", "_encrypt")

def __init__(self, cipher_state: CipherState) -> None:
"""Initialize the cipher wrapper."""
crypto_cipher: CryptographyCipher = cipher_state.cipher
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
self._nonce: int = cipher_state.n
self._encrypt = cipher.encrypt

def encrypt(self, data: _bytes) -> bytes:
"""Encrypt a frame."""
ciphertext = self._encrypt(PACK_NONCE(self._nonce), data, None)
self._nonce += 1
return ciphertext


class DecryptCipher:
"""Wrapper around the ChaCha20Poly1305 cipher for decryption."""

__slots__ = ("_nonce", "_decrypt")

def __init__(self, cipher_state: CipherState) -> None:
"""Initialize the cipher wrapper."""
crypto_cipher: CryptographyCipher = cipher_state.cipher
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
self._nonce: int = cipher_state.n
self._decrypt = cipher.decrypt

def decrypt(self, data: _bytes) -> bytes:
"""Decrypt a frame."""
plaintext = self._decrypt(PACK_NONCE(self._nonce), data, None)
self._nonce += 1
return plaintext


class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""

Expand All @@ -77,8 +119,8 @@ class APINoiseFrameHelper(APIFrameHelper):
"_state",
"_server_name",
"_proto",
"_decrypt",
"_encrypt",
"_encrypt_cipher",
"_decrypt_cipher",
)

def __init__(
Expand All @@ -95,8 +137,8 @@ def __init__(
self._expected_name = expected_name
self._state = NOISE_STATE_HELLO
self._server_name: str | None = None
self._decrypt: Callable[[bytes], bytes] | None = None
self._encrypt: Callable[[bytes], bytes] | None = None
self._encrypt_cipher: EncryptCipher | None = None
self._decrypt_cipher: DecryptCipher | None = None
self._setup_proto()

def close(self) -> None:
Expand Down Expand Up @@ -271,14 +313,8 @@ def _handle_handshake(self, msg: bytes) -> None:
self._proto.read_message(msg[1:])
self._state = NOISE_STATE_READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
None,
)
self._encrypt = partial(
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member
self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member
self.ready_future.set_result(None)

def write_packets(
Expand All @@ -289,7 +325,7 @@ def write_packets(
Packets are in the format of tuple[protobuf_type, protobuf_data]
"""
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
assert self._encrypt_cipher is not None, "Handshake should be complete"

out: list[bytes] = []
for packet in packets:
Expand All @@ -304,7 +340,7 @@ def write_packets(
data_len & 0xFF,
)
)
frame = self._encrypt(data_header + data)
frame = self._encrypt_cipher.encrypt(data_header + data)
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
out.append(header)
Expand All @@ -315,8 +351,8 @@ def write_packets(
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete"
msg = self._decrypt(frame)
assert self._decrypt_cipher is not None, "Handshake should be complete"
msg = self._decrypt_cipher.decrypt(frame)
# Message layout is
# 2 bytes: message type
# 2 bytes: message length
Expand Down

0 comments on commit 7e7ece4

Please sign in to comment.