Skip to content

Commit

Permalink
Move varuint functions into plain_text frame_helper (#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 17, 2023
1 parent b059dd6 commit 63897ed
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 51 deletions.
5 changes: 5 additions & 0 deletions aioesphomeapi/_frame_helper/plain_text.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes

cpdef _varuint_to_bytes(cython.int value)

@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
cpdef _bytes_to_varuint(cython.bytes value)

cdef class APIPlaintextFrameHelper(APIFrameHelper):

@cython.locals(
Expand Down
42 changes: 41 additions & 1 deletion aioesphomeapi/_frame_helper/plain_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,54 @@

import asyncio
import logging
from functools import lru_cache
from typing import TYPE_CHECKING

from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
from ..util import bytes_to_varuint, varuint_to_bytes
from .base import WRITE_EXCEPTIONS, APIFrameHelper

_LOGGER = logging.getLogger(__name__)

_int = int
_bytes = bytes


def _varuint_to_bytes(value: _int) -> bytes:
"""Convert a varuint to bytes."""
if value <= 0x7F:
return bytes((value,))

result = []
while value:
temp = value & 0x7F
value >>= 7
if value:
result.append(temp | 0x80)
else:
result.append(temp)

return bytes(result)


_cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
varuint_to_bytes = _cached_varuint_to_bytes


def _bytes_to_varuint(value: _bytes) -> _int | None:
"""Convert bytes to a varuint."""
result = 0
bitpos = 0
for val in value:
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return None


_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
bytes_to_varuint = _cached_bytes_to_varuint


class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
Expand Down
30 changes: 0 additions & 30 deletions aioesphomeapi/util.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,6 @@
from __future__ import annotations

import math
from functools import lru_cache


@lru_cache(maxsize=1024)
def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F:
return bytes([value])

ret = b""
while value:
temp = value & 0x7F
value >>= 7
if value:
ret += bytes([temp | 0x80])
else:
ret += bytes([temp])

return ret


@lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> int | None:
result = 0
bitpos = 0
for val in value:
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return None


def fix_float_single_double_conversion(value: float) -> float:
Expand Down
31 changes: 30 additions & 1 deletion tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@

from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint,
)
from aioesphomeapi._frame_helper.plain_text import (
_cached_varuint_to_bytes as cached_varuint_to_bytes,
)
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from aioesphomeapi.core import (
BadNameAPIError,
InvalidEncryptionKeyAPIError,
SocketAPIError,
)
from aioesphomeapi.util import varuint_to_bytes

PREAMBLE = b"\x00"

Expand Down Expand Up @@ -234,3 +241,25 @@ def _on_error(exc: Exception):

with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30)


VARUINT_TESTCASES = [
(0, b"\x00"),
(42, b"\x2a"),
(127, b"\x7f"),
(128, b"\x80\x01"),
(300, b"\xac\x02"),
(65536, b"\x80\x80\x04"),
]


@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_varuint_to_bytes(val, encoded):
assert varuint_to_bytes(val) == encoded
assert cached_varuint_to_bytes(val) == encoded


@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_bytes_to_varuint(val, encoded):
assert bytes_to_varuint(encoded) == val
assert cached_bytes_to_varuint(encoded) == val
19 changes: 0 additions & 19 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,6 @@

from aioesphomeapi import util

VARUINT_TESTCASES = [
(0, b"\x00"),
(42, b"\x2a"),
(127, b"\x7f"),
(128, b"\x80\x01"),
(300, b"\xac\x02"),
(65536, b"\x80\x80\x04"),
]


@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_varuint_to_bytes(val, encoded):
assert util.varuint_to_bytes(val) == encoded


@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_bytes_to_varuint(val, encoded):
assert util.bytes_to_varuint(encoded) == val


@pytest.mark.parametrize(
"input, output",
Expand Down

0 comments on commit 63897ed

Please sign in to comment.