Skip to content

Commit

Permalink
Small reduction in connect overhead (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 15, 2023
1 parent 4be79ea commit c5f4bfa
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 42 deletions.
24 changes: 23 additions & 1 deletion aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,36 @@ cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef bint TYPE_CHECKING

cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE

cdef object asyncio_timeout
cdef object CancelledError
cdef object asyncio_TimeoutError

cdef object ConnectResponse
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest
cdef object GetTimeRequest, GetTimeResponse

cdef object APIVersion

cdef object partial

cdef object hr

cdef object RESOLVE_TIMEOUT
cdef object CONNECT_AND_SETUP_TIMEOUT

cdef object APIConnectionError
cdef object BadNameAPIError
cdef object HandshakeAPIError
cdef object PingFailedAPIError
cdef object ReadFailedAPIError
cdef object TimeoutAPIError


cdef class APIConnection:

cdef object _params
Expand Down
87 changes: 46 additions & 41 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import socket
import sys
import time

# After we drop support for Python 3.10, we can use the built-in TimeoutError
# instead of the one from asyncio since they are the same in Python 3.11+
from asyncio import CancelledError
from asyncio import TimeoutError as asyncio_TimeoutError
from collections.abc import Coroutine
from dataclasses import astuple, dataclass
from functools import partial
Expand Down Expand Up @@ -60,6 +65,7 @@
INTERNAL_MESSAGE_TYPES = {GetTimeRequest, PingRequest, DisconnectRequest}

DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()

Expand Down Expand Up @@ -187,8 +193,9 @@ def __init__(

self._ping_timer: asyncio.TimerHandle | None = None
self._pong_timer: asyncio.TimerHandle | None = None
self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
keepalive = params.keepalive
self._keep_alive_interval = keepalive
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO

self._start_connect_task: asyncio.Task[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None
Expand All @@ -209,7 +216,7 @@ def _cleanup(self) -> None:
Safe to call multiple times.
"""
if self.connection_state == ConnectionState.CLOSED:
if self.connection_state is ConnectionState.CLOSED:
return
was_connected = self.is_connected
self._set_connection_state(ConnectionState.CLOSED)
Expand Down Expand Up @@ -249,7 +256,7 @@ def _cleanup(self) -> None:
self._ping_timer.cancel()
self._ping_timer = None

if self.on_stop and was_connected:
if self.on_stop is not None and was_connected:
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
Expand Down Expand Up @@ -277,30 +284,29 @@ async def _connect_resolve_host(self) -> hr.AddrInfo:
)
async with asyncio_timeout(RESOLVE_TIMEOUT):
return await coro
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
raise ResolveAPIError(
f"Timeout while resolving IP address for {self.log_name}"
) from err

async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
"""Step 2 in connect process: connect the socket."""
self._socket = socket.socket(
family=addr.family, type=addr.type, proto=addr.proto
)
self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
except OSError as err:
_LOGGER.warning(
"%s: Failed to set socket receive buffer size: %s",
self.log_name,
err,
)

if self._debug_enabled():
if debug_enable is True:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
Expand All @@ -311,29 +317,30 @@ async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
sockaddr = astuple(addr.sockaddr)

try:
coro = self._loop.sock_connect(self._socket, sockaddr)
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
await coro
except asyncio.TimeoutError as err:
await self._loop.sock_connect(sock, sockaddr)
except asyncio_TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err

_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
self._socket = sock
if debug_enable is True:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)

async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop
assert self._socket is not None

if self._params.noise_psk is None:
if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
Expand All @@ -345,11 +352,9 @@ async def _connect_init_frame_helper(self) -> None:
)
else:
# Ensure noise_psk is a string and not an EStr
noise_psk = str(self._params.noise_psk)
assert noise_psk is not None
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper(
noise_psk=noise_psk,
noise_psk=str(noise_psk),
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
Expand All @@ -359,14 +364,14 @@ async def _connect_init_frame_helper(self) -> None:
sock=self._socket,
)

self._frame_helper = fh
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
self._frame_helper = fh

async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
Expand Down Expand Up @@ -433,7 +438,7 @@ def _async_send_keep_alive(self) -> None:
self._pong_timer = loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received
)
elif self._debug_enabled():
elif self._debug_enabled() is True:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping
Expand Down Expand Up @@ -500,11 +505,11 @@ async def start_connection(self) -> None:
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await start_connect_task
except (Exception, asyncio.CancelledError) as ex:
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
Expand Down Expand Up @@ -547,11 +552,11 @@ async def finish_connection(self, *, login: bool) -> None:
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await self._finish_connect_task
except (Exception, asyncio.CancelledError) as ex:
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
Expand All @@ -567,8 +572,8 @@ async def finish_connection(self, *, login: bool) -> None:
def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change."""
self.connection_state = state
self.is_connected = state == ConnectionState.CONNECTED
self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE
self.is_connected = state is ConnectionState.CONNECTED
self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE

async def _login(self) -> None:
"""Send a login (ConnectRequest) and await the response."""
Expand Down Expand Up @@ -606,7 +611,7 @@ def send_message(self, msg: message.Message) -> None:
if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None:
raise ValueError(f"Message type id not found for type {msg_type}")

if self._debug_enabled():
if self._debug_enabled() is True:
_LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg)

if TYPE_CHECKING:
Expand Down Expand Up @@ -667,7 +672,7 @@ def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if fut.done():
return
fut.set_exception(asyncio.TimeoutError)
fut.set_exception(asyncio_TimeoutError)

def _handle_complex_message(
self,
Expand Down Expand Up @@ -727,7 +732,7 @@ async def send_message_await_response_complex( # pylint: disable=too-many-local
timeout_expired = False
try:
await fut
except asyncio.TimeoutError as err:
except asyncio_TimeoutError as err:
timeout_expired = True
raise TimeoutAPIError(
f"Timeout waiting for response for {type(send_msg)} after {timeout}s"
Expand Down Expand Up @@ -761,7 +766,7 @@ def _report_fatal_error(self, err: Exception) -> None:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if not self._expected_disconnect and not self._fatal_exception:
if self._expected_disconnect is False and not self._fatal_exception:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
Expand Down Expand Up @@ -806,7 +811,7 @@ def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:

msg_type = type(msg)

if self._debug_enabled():
if self._debug_enabled() is True:
_LOGGER.debug(
"%s: Got message of type %s: %s",
self.log_name,
Expand All @@ -830,7 +835,7 @@ def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
handler(msg)

if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse())
self.send_message(DISCONNECT_RESPONSE_MESSAGE)
self._expected_disconnect = True
self._cleanup()
elif msg_type is PingRequest:
Expand Down

0 comments on commit c5f4bfa

Please sign in to comment.