diff --git a/aioasuswrt/asuswrt.py b/aioasuswrt/asuswrt.py index 78c0bfb..c633d10 100644 --- a/aioasuswrt/asuswrt.py +++ b/aioasuswrt/asuswrt.py @@ -6,7 +6,7 @@ from collections import namedtuple from datetime import datetime -from aioasuswrt.connection import SshConnection, TelnetConnection +from aioasuswrt.connection import create_connection from aioasuswrt.helpers import convert_size _LOGGER = logging.getLogger(__name__) @@ -250,10 +250,9 @@ def __init__( self.interface = interface self.dnsmasq = dnsmasq - if use_telnet: - self.connection = TelnetConnection(host, port, username, password) - else: - self.connection = SshConnection(host, port, username, password, ssh_key) + self.connection = create_connection( + use_telnet, host, port, username, password, ssh_key + ) async def async_get_nvram(self, to_get): """Gets nvram""" diff --git a/aioasuswrt/connection.py b/aioasuswrt/connection.py index e1adf9c..df630de 100644 --- a/aioasuswrt/connection.py +++ b/aioasuswrt/connection.py @@ -1,9 +1,11 @@ """Module for connections.""" +import abc import asyncio -from asyncio import IncompleteReadError import logging -from asyncio import LimitOverrunError, TimeoutError +from asyncio import IncompleteReadError, LimitOverrunError, TimeoutError +from asyncio.streams import StreamReader, StreamWriter from math import floor +from typing import List, Optional import asyncssh @@ -13,59 +15,149 @@ asyncssh.set_log_level("WARNING") -class SshConnection: +class _CommandException(Exception): + pass + + +class _BaseConnection(abc.ABC): + def __init__( + self, host: str, port: int, username: Optional[str], password: Optional[str] + ): + self._host = host + self._port = port + self._username = username if username else None + self._password = password if password else None + + self._io_lock = asyncio.Lock() + + @property + def description(self) -> str: + """ Description of the connection.""" + ret = f"{self._host}:{self._port}" + if self._username: + ret = f"{self._username}@{ret}" + + return ret + + async def async_run_command(self, command: str, retry=True) -> List[str]: + """ Call a command using the connection.""" + async with self._io_lock: + if not self.is_connected: + await self.async_connect() + + try: + return await self._async_call_command(command) + except _CommandException: + pass + + # The command failed + if retry: + _LOGGER.debug(f"Retrying command: {command}") + return await self._async_call_command(command) + return [] + + async def async_connect(self): + if self.is_connected: + _LOGGER.debug(f"Connection already established to: {self.description}") + return + + await self._async_connect() + + async def async_disconnect(self): + """Disconnects the client""" + async with self._io_lock: + self._disconnect() + + @abc.abstractmethod + async def _async_call_command(self, command: str) -> List[str]: + """ Call the command.""" + pass + + @abc.abstractmethod + async def _async_connect(self): + """ Establish a connection.""" + pass + + @abc.abstractmethod + def _disconnect(self): + """ Disconnect.""" + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """Do we have a connection.""" + pass + + +def create_connection( + use_telnet: bool, + host: str, + port: Optional[int], + username: Optional[str], + password: Optional[str], + ssh_key: Optional[str], +) -> _BaseConnection: + + if use_telnet: + return TelnetConnection( + host=host, port=port, username=username, password=password + ) + else: + return SshConnection( + host=host, port=port, username=username, password=password, ssh_key=ssh_key + ) + + +class SshConnection(_BaseConnection): """Maintains an SSH connection to an ASUS-WRT router.""" - def __init__(self, host, port, username, password, ssh_key): + def __init__( + self, + host: str, + port: Optional[int], + username: Optional[str], + password: Optional[str], + ssh_key: Optional[str], + ): """Initialize the SSH connection properties.""" - self._host = host - self._port = port or 22 - self._username = username - self._password = password + super().__init__(host, port or 22, username, password) self._ssh_key = ssh_key self._client = None - async def async_run_command(self, command, retry=False): + async def _async_call_command(self, command: str) -> List[str]: """Run commands through an SSH connection. Connect to the SSH server if not currently connected, otherwise use the existing connection. """ - if self._client is None and not retry: - await self.async_connect() - return await self.async_run_command(command, retry=True) - else: - if self._client is not None: - try: - result = await asyncio.wait_for( - self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)), - 9, - ) - except asyncssh.misc.ChannelOpenError: - if not retry: - await self.async_connect() - return await self.async_run_command(command, retry=True) - else: - _LOGGER.error("Cant connect to host, giving up!") - return [] - except TimeoutError: - self._client = None - _LOGGER.error("Host timeout.") - return [] - - return result.stdout.split("\n") - - else: - _LOGGER.error("Cant connect to host, giving up!") - return [] + try: + if not self.is_connected: + await self._async_connect() + if not self._client: + raise _CommandException + + result = await asyncio.wait_for( + self._client.run(f"{_PATH_EXPORT_COMMAND} && {command}"), + 9, + ) + except asyncssh.misc.ChannelOpenError as ex: + self._disconnect() + _LOGGER.warning("Not connected to host") + raise _CommandException from ex + except TimeoutError as ex: + self._disconnect() + _LOGGER.error("Host timeout.") + raise _CommandException from ex + + return result.stdout.split("\n") @property - def is_connected(self): + def is_connected(self) -> bool: """Do we have a connection.""" return self._client is not None - async def async_connect(self): + async def _async_connect(self): """Fetches the client or creates a new one.""" - kwargs = { "username": self._username if self._username else None, "client_keys": [self._ssh_key] if self._ssh_key else None, @@ -76,72 +168,64 @@ async def async_connect(self): self._client = await asyncssh.connect(self._host, **kwargs) + def _disconnect(self): + self._client = None + -class TelnetConnection: +class TelnetConnection(_BaseConnection): """Maintains a Telnet connection to an ASUS-WRT router.""" - def __init__(self, host, port, username, password): + def __init__( + self, + host: str, + port: Optional[int], + username: Optional[str], + password: Optional[str], + ): """Initialize the Telnet connection properties.""" - self._reader = None - self._writer = None - self._host = host - self._port = port or 23 - self._username = username - self._password = password - self._prompt_string = None - self._io_lock = asyncio.Lock() - self._linebreak = None + super().__init__(host, port or 23, username, password) + self._reader: Optional[StreamReader] = None + self._writer: Optional[StreamWriter] = None + self._prompt_string = "".encode("ascii") + self._linebreak: Optional[float] = None - async def async_run_command(self, command, first_try=True): + async def _async_call_command(self, command): """Run a command through a Telnet connection. If first_try is True a second attempt will be done if the first try fails.""" + try: + if not self.is_connected: + await self._async_connect() - need_retry = False + if self._linebreak is None: + self._linebreak = await self._async_linebreak() - async with self._io_lock: - try: - if not self.is_connected: - await self._async_connect() - # Let's add the path and send the command - full_cmd = f"{_PATH_EXPORT_COMMAND} && {command}" - self._writer.write((full_cmd + "\n").encode("ascii")) - # And read back the data till the prompt string - data = await asyncio.wait_for( - self._reader.readuntil(self._prompt_string), 9 - ) - except (BrokenPipeError, LimitOverrunError, IncompleteReadError): - # Writing has failed, Let's close and retry if necessary - self.disconnect() - if first_try: - need_retry = True - else: - _LOGGER.warning("connection is lost to host.") - return [] - except TimeoutError: - _LOGGER.error("Host timeout.") - self.disconnect() - if first_try: - need_retry = True - else: - return [] - - if need_retry: - _LOGGER.debug("Trying one more time") - return await self.async_run_command(command, False) + if not self._writer or not self._reader: + raise _CommandException + + # Let's add the path and send the command + full_cmd = f"{_PATH_EXPORT_COMMAND} && {command}" + self._writer.write((full_cmd + "\n").encode("ascii")) + # And read back the data till the prompt string + data = await asyncio.wait_for( + self._reader.readuntil(self._prompt_string), 9 + ) + except (BrokenPipeError, LimitOverrunError, IncompleteReadError) as ex: + # Writing has failed, Let's close and retry if necessary + _LOGGER.warning("connection is lost to host.") + self._disconnect() + raise _CommandException from ex + except TimeoutError as ex: + _LOGGER.error("Host timeout.") + self._disconnect() + raise _CommandException from ex # Let's process the received data - data = data.split(b"\n") + data_list = data.split(b"\n") # Let's find the number of elements the cmd takes cmd_len = len(self._prompt_string) + len(full_cmd) # We have to do floor + 1 to handle the infinite case correct start_split = floor(cmd_len / self._linebreak) + 1 - data = data[start_split:-1] - return [line.decode("utf-8") for line in data] - - async def async_connect(self): - """Connect to the ASUS-WRT Telnet server.""" - async with self._io_lock: - await self._async_connect() + return [line.decode("utf-8") for line in data_list[start_split:-1]] async def _async_connect(self): self._reader, self._writer = await asyncio.open_connection( @@ -159,50 +243,56 @@ async def _async_connect(self): return except TimeoutError: _LOGGER.error("Host timeout.") - self.disconnect() - self._writer.write((self._username + "\n").encode("ascii")) + self._disconnect() + + self._writer.write((self._username or "" + "\n").encode("ascii")) # Enter the password await self._reader.readuntil(b"Password: ") - self._writer.write((self._password + "\n").encode("ascii")) + self._writer.write((self._password or "" + "\n").encode("ascii")) # Now we can determine the prompt string for the commands. self._prompt_string = (await self._reader.readuntil(b"#")).split(b"\n")[-1] + async def _async_linebreak(self) -> float: + """Telnet or asyncio seems to be adding linebreaks due to terminal size, + try to determine here what the column number is.""" # Let's determine if any linebreaks are added # Write some arbitrary long string. - if self._linebreak is None: - self._writer.write((" " * 200 + "\n").encode("ascii")) - self._determine_linebreak( - await self._reader.readuntil(self._prompt_string) - ) + if not self._writer or not self._reader: + raise _CommandException - def _determine_linebreak(self, input_bytes: bytes): - """Telnet or asyncio seems to be adding linebreaks due to terminal size, - try to determine here what the column number is.""" + self._writer.write((" " * 200 + "\n").encode("ascii")) + input_bytes = await self._reader.readuntil(self._prompt_string) + + return self._determine_linebreak(input_bytes) + + def _determine_linebreak(self, input_bytes: bytes) -> float: # Let's convert the data to the expected format data = input_bytes.decode("utf-8").replace("\r", "").split("\n") if len(data) == 1: # There was no split, so assume infinite - self._linebreak = float("inf") + linebreak = float("inf") else: # The linebreak is the length of the prompt string + the first line - self._linebreak = len(self._prompt_string) + len(data[0]) + linebreak = len(self._prompt_string) + len(data[0]) if len(data) > 2: # We can do a quick sanity check, as there are more linebreaks - if len(data[1]) != self._linebreak: + if len(data[1]) != linebreak: _LOGGER.warning( - f"Inconsistent linebreaks {len(data[1])} != " - f"{self._linebreak}" + f"Inconsistent linebreaks {len(data[1])} != " f"{linebreak}" ) + return linebreak + @property - def is_connected(self): + def is_connected(self) -> bool: """Do we have a connection.""" return self._reader is not None and self._writer is not None - def disconnect(self): - """Disconnects the client""" + def _disconnect(self): + """ Disconnect the connection, ensure that the caller holds the io_lock.""" self._writer = None self._reader = None + self._linebreak = None diff --git a/aioasuswrt/mocks/telnet_mock.py b/aioasuswrt/mocks/telnet_mock.py index 30282b9..d218122 100644 --- a/aioasuswrt/mocks/telnet_mock.py +++ b/aioasuswrt/mocks/telnet_mock.py @@ -1,16 +1,17 @@ """ Mock library for the Telnet connection, especially mocking the reader/writer of asyncio """ +import asyncio import textwrap -import typing +from typing import Optional, Tuple -_READER = None -_WRITER = None +_READER: Optional["MockReader"] = None +_WRITER: Optional["MockWriter"] = None _RETURN_VAL = "".encode("ascii") _PROMPT = "".encode("ascii") _LINEBREAK = float("inf") -_NEXT_EXCEPTION = None +_NEXT_EXCEPTION: Optional[Exception] = None class MockWriter: @@ -45,13 +46,16 @@ def set_linebreak(self, linebreak: int): def set_cmd(self, new_cmd: bytes): # The asyncio telnet connection adds '\r\rn' commands for every # strings bigger than the linebreak. So let's add that here. - self._cmd = "\r\r\n".join( - textwrap.wrap( - _PROMPT.decode("utf-8") + " " + new_cmd.decode("utf-8"), - width=_LINEBREAK, - drop_whitespace=False, - ) - ).encode("ascii") + try: + self._cmd = "\r\r\n".join( + textwrap.wrap( + _PROMPT.decode("utf-8") + " " + new_cmd.decode("utf-8"), + width=int(_LINEBREAK), + drop_whitespace=False, + ) + ).encode("ascii") + except OverflowError: + self._comd = new_cmd async def readuntil(self, read_till: bytes) -> bytes: # Let's create the return string from the cmd and the return string @@ -67,6 +71,7 @@ def set_prompt(new_prompt): def set_return(new_return: str): global _RETURN_VAL + print(f"set reutrn: {new_return}") _RETURN_VAL = new_return.encode("ascii") @@ -80,8 +85,9 @@ def raise_exception_on_write(exception_type): _NEXT_EXCEPTION = exception_type -async def open_connection(*args, **kwargs) -> typing.Tuple[MockReader, MockWriter]: +async def open_connection(*args, **kwargs) -> Tuple[MockReader, MockWriter]: global _READER, _WRITER + print("MOCKED OPEN") _READER = MockReader() _WRITER = MockWriter() # Clear previously configured variables. diff --git a/tests/test_connection.py b/tests/test_connection.py index 310daf0..d714dec 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -234,9 +234,9 @@ class TestTelnetConnection(TestCase): def setUp(self): """Set up test env.""" - self.connection = TelnetConnection("fake", "fake", "fake", "fake") + self.connection = TelnetConnection("fake", 2, "fake", "fake") # self.connection._connected = True - self.connection._prompt_string = "" + self.connection._prompt_string = "".encode("ascii") def test_determine_linelength_inf(self): """ Test input for infinite breakline length.""" @@ -244,27 +244,27 @@ def test_determine_linelength_inf(self): # The input string is shorter than the limit for i in (15, 50): input_bytes = (" " * i).encode("ascii") - self.connection._determine_linebreak(input_bytes) - self.assertEqual(self.connection._linebreak, float("inf")) + linebreak = self.connection._determine_linebreak(input_bytes) + self.assertEqual(linebreak, float("inf")) def test_determine_linelength(self): for i in (15, 50): input_bytes = (" " * i + "\n" + " " * 5).encode("ascii") - self.connection._determine_linebreak(input_bytes) - self.assertEqual(self.connection._linebreak, i) + linebreak = self.connection._determine_linebreak(input_bytes) + self.assertEqual(linebreak, i) # And now with some more lines input_bytes = ((" " * i + "\n") * 3 + " " * 5).encode("ascii") - self.connection._determine_linebreak(input_bytes) - self.assertEqual(self.connection._linebreak, i) + linebreak = self.connection._determine_linebreak(input_bytes) + self.assertEqual(linebreak, i) # And with a prompt string prompt = "test_string" - input_bytes = "a" * (i - len(prompt)) + "\n" + "a" * 5 - self.connection._prompt_string = prompt - self.connection._determine_linebreak(input_bytes.encode("ascii")) - self.assertEqual(self.connection._linebreak, i) - self.connection._prompt_string = "" + input_bytes = ("a" * (i - len(prompt)) + "\n" + "a" * 5).encode("ascii") + self.connection._prompt_string = prompt.encode("ascii") + linebreak = self.connection._determine_linebreak(input_bytes) + self.assertEqual(linebreak, i) + self.connection._prompt_string = "".encode("ascii") @pytest.mark.asyncio @@ -273,23 +273,28 @@ async def test_sending_cmds(): # Let's set a short linebreak of 10 telnet_mock.set_linebreak(22) - connection = TelnetConnection("fake", "fake", "fake", "fake") + connection = TelnetConnection("fake", 2, "fake", "fake") + print("Doing connection") await connection.async_connect() + print("Fin connection") # Now let's send some arbitrary short command exp_ret_val = "Some arbitrary long return string." + "." * 100 telnet_mock.set_return(exp_ret_val) new_return = await connection.async_run_command("run command\n") + print(new_return) assert new_return[0] == exp_ret_val @pytest.mark.asyncio async def test_reconnect(): with mock.patch("asyncio.open_connection", new=telnet_mock.open_connection): - connection = TelnetConnection("fake", "fake", "fake", "fake") + connection = TelnetConnection("fake", 2, "fake", "fake") await connection.async_connect() - telnet_mock.raise_exception_on_write(IncompleteReadError("", 42)) + telnet_mock.raise_exception_on_write( + IncompleteReadError("".encode("ascii"), 42) + ) new_return = await connection.async_run_command("run command\n") assert new_return == [""]