From 0b330d15040fe216241dc10158777ce811299276 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Mon, 22 Apr 2024 20:10:40 +0900 Subject: [PATCH 01/12] Refactor tube/sock --- ptrlib/binary/encoding/byteconv.py | 4 + ptrlib/connection/proc.py | 239 --------- ptrlib/connection/sock.py | 228 +++++--- ptrlib/connection/ssh.py | 66 --- ptrlib/connection/tube.py | 812 +++++++++++++++++++---------- ptrlib/connection/winproc.py | 250 --------- 6 files changed, 672 insertions(+), 927 deletions(-) diff --git a/ptrlib/binary/encoding/byteconv.py b/ptrlib/binary/encoding/byteconv.py index ac0f587..7613f90 100644 --- a/ptrlib/binary/encoding/byteconv.py +++ b/ptrlib/binary/encoding/byteconv.py @@ -9,6 +9,8 @@ def bytes2str(data: bytes) -> str: """ if isinstance(data, bytes): return ''.join(list(map(chr, data))) + elif isinstance(data, str): + return data # Fallback else: raise ValueError("{} given ('bytes' expected)".format(type(data))) @@ -20,6 +22,8 @@ def str2bytes(data: str) -> bytes: return bytes(list(map(ord, data))) except ValueError: return data.encode('utf-8') + elif isinstance(data, bytes): + return data # Fallback else: raise ValueError("{} given ('str' expected)".format(type(data))) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index bed3244..e69de29 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -1,239 +0,0 @@ -# coding: utf-8 -from logging import getLogger -from typing import Any, List, Mapping -from ptrlib.arch.linux.sig import * -from ptrlib.binary.encoding import * -from .tube import * -from .winproc import * -import errno -import select -import os -import subprocess -import time - -_is_windows = os.name == 'nt' -if not _is_windows: - import fcntl - import pty - import tty - -logger = getLogger(__name__) - - -class UnixProcess(Tube): - def __init__( - self, - args: Union[Union[bytes, str], List[Union[bytes, str]]], - env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, - cwd: Optional[Union[bytes, str]]=None, - timeout: Optional[int]=None - ): - """Create a process - - Create a new process and make a pipe. - - Args: - args (list): The arguments to pass - env (list) : The environment variables - - Returns: - Process: ``Process`` instance. - """ - assert not _is_windows - super().__init__() - - if isinstance(args, list): - self.args = args - self.filepath = args[0] - else: - self.args = [args] - self.filepath = args - self.env = env - self.default_timeout = timeout - self.timeout = self.default_timeout - self.proc = None - self.returncode = None - - # Open pty on Unix - master, self.slave = pty.openpty() - tty.setraw(master) - tty.setraw(self.slave) - - # Create a new process - try: - self.proc = subprocess.Popen( - self.args, - cwd = cwd, - env = self.env, - shell = False, - stdout=self.slave, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE - ) - except FileNotFoundError: - logger.warning("Executable not found: '{0}'".format(self.filepath)) - return - - # Duplicate master - if master is not None: - self.proc.stdout = os.fdopen(os.dup(master), 'r+b', 0) - os.close(master) - - # Set in non-blocking mode - fd = self.proc.stdout.fileno() - fl = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) - logger.info("Successfully created new process (PID={})".format(self.proc.pid)) - - def _settimeout(self, timeout: Optional[Union[int, float]]): - if timeout is None: - self.timeout = self.default_timeout - elif timeout > 0: - self.timeout = timeout - - def _socket(self) -> Optional[Any]: - return self.proc - - def _poll(self) -> Optional[int]: - if self.proc is None: - return False - - # Check if the process exits - self.proc.poll() - returncode = self.proc.returncode - if returncode is not None and self.returncode is None: - self.returncode = returncode - name = signal_name(-returncode, detail=True) - if name: name = '--> ' + name - logger.error( - "Process '{}' (pid={}) stopped with exit code {} {}".format( - self.filepath, self.proc.pid, returncode, name - )) - return returncode - - def is_alive(self) -> bool: - """Check if the process is alive""" - return self._poll() is None - - def _can_recv(self) -> bool: - """Check if receivable""" - if self.proc is None: - return False - - try: - r = select.select( - [self.proc.stdout], [], [], self.timeout - ) - if r == ([], [], []): - raise TimeoutError("Receive timeout", b'') - else: - # assert r == ([self.proc.stdout], [], []) - return True - except TimeoutError as e: - raise e from None - except select.error as v: - if v[0] == errno.EINTR: - return False - assert False, "unreachable" - - def _recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data - - Receive raw data of maximum `size` bytes length through the pipe. - - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) - - Returns: - bytes: The received data - """ - self._settimeout(timeout) - - if not self._can_recv(): - return b'' - - try: - data = self.proc.stdout.read(size) - except subprocess.TimeoutExpired: - # TODO: Unreachable? - raise TimeoutError("Receive timeout", b'') from None - - self._poll() # poll after received all data - return data - - def _send(self, data: Union[str, bytes]): - """Send raw data - - Send raw data through the socket - - Args: - data (bytes) : Data to send - """ - self._poll() - if isinstance(data, str): - data = str2bytes(data) - elif not isinstance(data, bytes): - logger.warning("Expected 'str' or 'bytes' but {} given".format( - type(data) - )) - - try: - self.proc.stdin.write(data) - self.proc.stdin.flush() - except IOError: - logger.warning("Broken pipe") - - def close(self): - """Close the socket - - Close the socket. - This method is called from the destructor. - """ - if self.proc: - os.close(self.slave) - self.proc.stdin.close() - self.proc.stdout.close() - if self.is_alive(): - self.proc.kill() - self.proc.wait() - logger.info("'{0}' (PID={1}) killed".format(self.filepath, self.proc.pid)) - self.proc = None - else: - logger.info("'{0}' (PID={1}) has already exited".format(self.filepath, self.proc.pid)) - self.proc = None - - def shutdown(self, target: Literal['send', 'recv']): - """Kill one connection - - Close send/recv pipe. - - Args: - target (str): Connection to close (`send` or `recv`) - """ - if target in ['write', 'send', 'stdin']: - self.proc.stdin.close() - - elif target in ['read', 'recv', 'stdout', 'stderr']: - self.proc.stdout.close() - - else: - logger.error("You must specify `send` or `recv` as target.") - - def wait(self) -> int: - """Wait until the process dies - - Wait until the process exits and get the status code. - - Returns: - code (int): Status code of the process - """ - while self.is_alive(): - time.sleep(0.1) - return self.returncode - - def __del__(self): - self.close() - -Process = WinProcess if _is_windows else UnixProcess -process = Process # alias for the Process diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index c364a93..823cf5b 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -1,142 +1,187 @@ -# coding: utf-8 +import errno +import select +import socket from logging import getLogger - +from typing import Literal, Optional, Union from ptrlib.binary.encoding import * -from .tube import * -import socket - -logger = getLogger(__name__) +from .tube import Tube class Socket(Tube): - def __init__(self, host: Union[str, bytes], port: Optional[int]=None, - timeout: Optional[Union[int, float]]=None, - ssl: bool=False, sni: Union[str, bool]=True): + # + # Constructor + # + def __init__(self, + host: Union[str, bytes], + port: Optional[int]=None, + ssl: bool=False, + sni: Union[str, bool]=True, + **kwargs): """Create a socket Create a new socket and establish a connection to the host. Args: - host (str): The host name or ip address of the server - port (int): The port number + host: Host name or ip address + port: Port number + ssl : Enable SSL/TLS + sni : SNI Returns: Socket: ``Socket`` instance. """ super().__init__() - if isinstance(host, bytes): - host = bytes2str(host) - + # Interpret host name and port number + host = bytes2str(host) if port is None: host = host.strip() if host.startswith('nc '): _, a, b = host.split() - host, port = a, int(b) elif host.count(':') == 1: a, b = host.split(':') - host, port = a, int(b) elif host.count(' ') == 1: a, b = host.split() - host, port = a, int(b) else: - raise ValueError("Specify port number") + raise ValueError("Port number is not given") + host, port = a, int(b) + + else: + port = int(port) + + self._host = host + self._port = port - self.host = host - self.port = port - self.timeout = timeout # Create a new socket - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if ssl: import ssl as _ssl self.context = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT) self.context.check_hostname = False self.context.verify_mode = _ssl.CERT_NONE if sni is True: - self.sock = self.context.wrap_socket(self.sock) + self._sock = self.context.wrap_socket(self._sock) else: - self.sock = self.context.wrap_socket(self.sock, server_hostname=sni) + self._sock = self.context.wrap_socket(self._sock, server_hostname=sni) + # Establish a connection try: - self.sock.connect((self.host, self.port)) - logger.info("Successfully connected to {0}:{1}".format(self.host, self.port)) + self._sock.connect((self._host, self._port)) + logger.info(f"Successfully connected to {self._host}:{self._port}") + except ConnectionRefusedError as e: - err = "Connection to {0}:{1} refused".format(self.host, self.port) - logger.warning(err) + logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None - def _settimeout(self, timeout: Optional[Union[int, float]]): - if timeout is None: - self.sock.settimeout(self.timeout) - elif timeout > 0: - self.sock.settimeout(timeout) + # + # Implementation of Tube methods + # + def _settimeout_impl(self, + timeout: Union[int, float]): + """Set timeout - def _socket(self) -> Optional[socket.socket]: - return self.sock + Args: + timeout: Timeout in second + """ + self._sock.settimeout(timeout) - def _recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: + def _recv_impl(self, size: int) -> bytes: """Receive raw data Receive raw data of maximum `size` bytes length through the socket. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size: Maximum data size to receive at once Returns: bytes: The received data - """ - self._settimeout(timeout) + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + """ try: - data = self.sock.recv(size) + data = self._sock.recv(size) + except socket.timeout: - raise TimeoutError("Receive timeout", b'') from None + raise TimeoutError("Timeout (_recv_impl)", b'') from None + except ConnectionAbortedError as e: - logger.warning("Connection aborted by the host") + logger.error("Connection aborted") raise e from None + except ConnectionResetError as e: - logger.warning("Connection reset by the host") + logger.error(f"Connection reset by {str(self)}") + raise e from None + + except OSError as e: + logger.error("OS Error") raise e from None return data - def _send(self, data: Union[str, bytes]): + def _send_impl(self, data: bytes) -> int: """Send raw data - Send raw data through the socket - - Args: - data (bytes) : Data to send + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - if isinstance(data, str): - data = str2bytes(data) - elif not isinstance(data, bytes): - logger.warning("Expected 'str' or 'bytes' but {} given".format( - type(data) - )) + data = str2bytes(data) try: - self.sock.send(data) + self._sock.send(data) + except BrokenPipeError as e: - logger.warning("Broken pipe") + logger.error("Broken pipe") raise e from None + except ConnectionAbortedError as e: - logger.warning("Connection aborted by the host") + logger.error("Connection aborted") + raise e from None + + except ConnectionResetError as e: + logger.error(f"Connection reset by {str(self)}") raise e from None - def close(self): - """Close the socket + except OSError as e: + logger.error("OS Error") + raise e from None - Close the socket. - This method is called from the destructor. + def _close_impl(self): + """Close socket """ - if self.sock: - self.sock.close() - self.sock = None - logger.info("Connection to {0}:{1} closed".format(self.host, self.port)) + self._sock.close() + + def _is_alive_impl(self) -> bool: + """Check if socket is alive + """ + try: + # Save timeout value since non-blocking mode will clear it + timeout = self._sock.gettimeout() + self._sock.setblocking(False) + + # Connection is closed if recv returns empty buffer + ret = len(self._sock.recv(1, socket.MSG_PEEK)) == 1 + + except BlockingIOError as e: + ret = True + + except (ConnectionResetError, socket.timeout): + ret = False - def shutdown(self, target: Literal['send', 'recv']): + finally: + self._sock.setblocking(True) + self._settimeout_impl(timeout) + + return ret + + def _shutdown_impl(self, target: Literal['send', 'recv']): """Kill one connection Close send/recv socket. @@ -145,28 +190,43 @@ def shutdown(self, target: Literal['send', 'recv']): target (str): Connection to close (`send` or `recv`) """ if target in ['write', 'send', 'stdin']: - self.sock.shutdown(socket.SHUT_WR) + self._sock.shutdown(socket.SHUT_WR) elif target in ['read', 'recv', 'stdout', 'stderr']: - self.sock.shutdown(socket.SHUT_RD) + self._sock.shutdown(socket.SHUT_RD) else: - logger.error("You must specify `send` or `recv` as target.") + raise ValueError("`target` must either 'send' or 'recv'") - def is_alive(self, timeout: Optional[Union[int, float]]=None) -> bool: - try: - self._settimeout(timeout) - data = self.sock.recv(1, socket.MSG_PEEK) - return True - except BlockingIOError: - return False - except ConnectionResetError: - return False - except socket.timeout: - return False + def __str__(self) -> str: + return f"{self._host}:{self._port}" + + + # + # Custom methods + # + def set_keepalive(self, + keep_idle: Optional[Union[int, float]]=None, + keep_interval: Optional[Union[int, float]]=None, + keep_count: Optional[Union[int, float]]=None): + """Set TCP keep-alive mode + + Send a keep-alive ping once every `keep_interval` seconds if activates + after `keep_idle` seconds of idleness, and closes the connection + after `keep_count` failed ping. + + Args: + keep_idle : Maximum duration to wait before sending keep-alive ping in second + keep_interval: Interval to send keep-alive ping in second + keep_count : Maximum number of failed attempts + """ + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if keep_idle is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keep_idle) + if keep_interval is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keep_interval) + if keep_count is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keep_count) - def __del__(self): - self.close() -# alias remote = Socket diff --git a/ptrlib/connection/ssh.py b/ptrlib/connection/ssh.py index 8eb7cdd..e69de29 100644 --- a/ptrlib/connection/ssh.py +++ b/ptrlib/connection/ssh.py @@ -1,66 +0,0 @@ -# coding: utf-8 -import shlex -import os -from ptrlib.binary.encoding import * -from ptrlib.arch.common import which -from .proc import * - -if os.name == 'nt': - _is_windows = True -else: - _is_windows = False - - -def SSH(host: str, port: int, username: str, - password: Optional[str]=None, identity: Optional[str]=None, - ssh_path: Optional[str]=None, expect_path: Optional[str]=None, - option: str='', command: str=''): - """Create an SSH shell - - Create a new process to connect to SSH server - - Args: - host (str) : SSH hostname - port (int) : SSH port - username (str): SSH username - password (str): SSH password - identity (str): Path of identity file - option (str) : Parameters to pass to SSH - command (str) : Initial command to execute on remote - - Returns: - Process: ``Process`` instance. - """ - assert isinstance(port, int) - if password is None and identity is None: - raise ValueError("You must give either password or identity") - - if ssh_path is None: - ssh_path = which('ssh') - if expect_path is None: - expect_path = which('expect') - - if not os.path.isfile(ssh_path): - raise FileNotFoundError("{}: SSH not found".format(ssh_path)) - if not os.path.isfile(expect_path): - raise FileNotFoundError("{}: 'expect' not found".format(expect_path)) - - if identity is not None: - option += ' -i {}'.format(shlex.quote(identity)) - - script = 'eval spawn {} -oStrictHostKeyChecking=no -oCheckHostIP=no {}@{} -p{} {} {}; interact; lassign [wait] pid spawnid err value; exit "$value"'.format( - ssh_path, - shlex.quote(username), - shlex.quote(host), - port, - option, - command - ) - - proc = Process( - [expect_path, '-c', script], - ) - if identity is None: - proc.sendlineafter("password: ", password) - - return proc diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index ad20b65..ce28e41 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -1,117 +1,166 @@ -# coding: utf-8 -import subprocess -from typing import Any, Optional, List, Tuple, Union, overload -try: - from typing import Literal -except: - from typing_extensions import Literal -from ptrlib.binary.encoding import * -from ptrlib.console.color import Color -from abc import ABCMeta, abstractmethod +import abc import re +import select import sys import threading -import time from logging import getLogger +from typing import List, Literal, Optional, Tuple, Union +from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8 +from ptrlib.console.color import Color logger = getLogger(__name__) -class Tube(metaclass=ABCMeta): - def __init__(self): - self.buf = b'' - self.debug = False +class Tube(metaclass=abc.ABCMeta): + """Abstract class for streaming data - @abstractmethod - def _settimeout(self, timeout: Optional[Union[int, float]]): - """Set timeout + A child class must implement the following methods: - Args: - timeout (float): Timeout (None: Set to default / -1: No change / x>0: Set timeout to x seconds) + - "_settimeout_impl" + - "_recv_impl" + - "_send_impl" + - "_close_impl" + - "_is_alive_impl + - "_shutdown_impl" + """ + # + # Decorator + # + def not_closed(method): + """Ensure that socket is not *explicitly* closed + """ + def decorator(*args, **kwargs): + assert isinstance(args[0], Tube), "Invalid usage of decorator" + if args[0]._is_closed: + raise BrokenPipeError("Socket has already been closed") + return method(*args, **kwargs) + return decorator + + # + # Constructor + # + def __init__(self, + timeout: Optional[Union[int, float]]=None): """ - pass - - @abstractmethod - def _recv(self, size: int, timeout: Union[int, float]) -> Optional[bytes]: - """Receive raw data - - Receive raw data of maximum `size` bytes length through the socket. - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) - - Returns: - bytes: The received data + timeout (float): Default timeout """ - pass + self._buffer = b'' - def unget(self, data: Union[str, bytes]): - """Revert data to socket + self._is_closed = False - Return data to socket. + self._default_timeout = timeout + self.settimeout() + # + # Methods + # + @not_closed + def settimeout(self, timeout: Optional[Union[int, float]]=None): + """Set timeout + Args: - data (bytes): Data to return + timeout (float): Timeout in second + + Note: + Set timeout to None in order to set the default timeout) + + Examples: + ``` + p = Socket("0.0.0.0", 1337, timeout=3) + # ... + p.settimeout(5) # Timeout is set to 5 + # ... + p.settimeout() # Timeout is set to 3 + ``` """ - if isinstance(data, str): - data = str2bytes(data) - self.buf = data + self.buf + assert timeout is None or (isinstance(timeout, (int, float)) and timeout >= 0), \ + "`timeout` must be positive and either int or float" - def recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data with buffering + if timeout is None: + if self._default_timeout is not None: + self._settimeout_impl(self._default_timeout) + else: + self._settimeout_impl(timeout) - Receive raw data of maximum `size` bytes length through the socket. + def recv(self, + size: int=4096, + timeout: Optional[Union[int, float]]=None) -> bytes: + """Receive data with buffering + + Receive raw data of at most `size` bytes. Args: - size (int): The data size to receive (Use `recvonce` - if you want to read exactly `size` bytes) - timeout (int): Timeout (in second) + size : Size to receive (Use `recvonce` to read exactly `size` bytes) + timeout: Timeout in second Returns: - bytes: The received data - """ - if size <= 0: - raise ValueError("`size` must be larger than 0") + bytes: Received data + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error - elif len(self.buf) == 0: - self._settimeout(timeout) + Examples: + ``` + tube.recv(4) try: - data = self._recv(size, timeout=-1) - except TimeoutError as err: - raise TimeoutError("`recv` timeout", b'') + tube.recv(timeout=3.14) + except TimeoutError: + pass + ``` + """ + assert size is None or (isinstance(size, int) and size >= 0), \ + "`size` must be a positive integer" + + # NOTE: We always return buffer if it's not empty + # This is because we do not know how many bytes we can read. + if len(self._buffer): + data, self._buffer = self._buffer[:size], self._buffer[size:] + return data - self.buf += data - if self.debug: - logger.info(f"Received {hex(len(data))} ({len(data)}) bytes:") - hexdump(data, prefix=" " + Color.CYAN, postfix=Color.END) + if timeout is not None: + self.settimeout(timeout) - # We don't check size > len(self.buf) because Python handles it - data, self.buf = self.buf[:size], self.buf[size:] + try: + self._buffer += self._recv_impl(size - len(self._buffer)) + + except TimeoutError as err: + data = self._buffer + err.args[1] + self._buffer = b'' + raise TimeoutError("Timeout (recv)", data) + + finally: + if timeout is not None: + # Reset timeout to default value + self.settimeout() + + data, self._buffer = self._buffer[:size], self._buffer[size:] return data - def recvonce(self, size: int, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data with buffering + def recvonce(self, + size: int, + timeout: Optional[Union[int, float]]=None) -> bytes: + """Receive raw data of exact size with buffering - Receive raw data of size `size` bytes length through the socket. + Receive raw data of exactly `size` bytes. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size : Data size to receive + timeout: Timeout in second Returns: - bytes: The received data + bytes: Received data """ - self._settimeout(timeout) data = b'' - timer_start = time.time() while len(data) < size: try: - data += self.recv(size - len(data), timeout=-1) + data += self.recv(size - len(data), timeout) except TimeoutError as err: - raise TimeoutError("`recvonce` timeout", data + err.args[1]) - time.sleep(0.01) + raise TimeoutError("Timeout (recvonce)", data + err.args[1]) if len(data) > size: self.unget(data[size:]) @@ -122,159 +171,187 @@ def recvuntil(self, size: int=4096, timeout: Optional[Union[int, float]]=None, drop: bool=False, - lookahead: bool=False, - sleep_time: float=0.01) -> bytes: + lookahead: bool=False) -> bytes: """Receive raw data until `delim` comes Args: - delim (bytes): The delimiter bytes - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool): Discard delimiter or not - lookahead (bool): Unget delimiter to buffer or not - sleep_time (float): Sleep time after receiving data + delim : The delimiter bytes + size : The data size to receive at once + timeout : Timeout in second + drop : Discard delimiter or not + lookahead: Unget delimiter to buffer or not Returns: bytes: Received data + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + + Examples: + ``` + echo.sendline("abc123def") + echo.recvuntil("123") # abc123 + + echo.sendline("abc123def") + echo.recvuntil("123", drop=True) # abc + + echo.sendline("abc123def") + echo.recvuntil("123", lookahead=True) # abc123 + echo.recvonce(6) # 123def + ``` """ - # Validate and normalize delimiter - if isinstance(delim, bytes): - delim = [delim] - elif isinstance(delim, str): - delim = [str2bytes(delim)] - elif isinstance(delim, list): - for i, t in enumerate(delim): - if isinstance(t, str): - delim[i] = str2bytes(t) - elif not isinstance(t, bytes): - raise ValueError(f"Delimiter must be either string or bytes: {t}") + assert isinstance(delim, (str, bytes, list)), \ + "`delim` must be either str, bytes, or list" + + # Preprocess + if isinstance(delim, list): + for i, d in enumerate(delim): + assert isinstance(d, (str, bytes)), \ + f"`delim[{i}]` must be either str or bytes" + delim[i] = str2bytes(delim) else: - raise ValueError(f"Delimiter must be either string, bytes, or list: {t}") + delim = [str2bytes(delim)] - self._settimeout(timeout) + # Iterate until we find one of the delimiters + found_delim = None + prev_len = 0 data = b'' - timer_start = time.time() - - found = False - token = None while True: try: - data += self.recv(size, timeout=-1) + data += self.recv(size, timeout) except TimeoutError as err: - raise TimeoutError("`recvuntil` timeout", data + err.args[1]) + raise TimeoutError("Timeout (recvuntil)", data + err.args[1]) - for t in delim: - if t in data: - found = True - token = t + for d in delim: + if d in data[max(0, prev_len-len(d)):]: + found_delim = d break - - if found: + if found_delim is not None: break - if sleep_time: - time.sleep(sleep_time) - found_pos = data.find(token) - result_len = found_pos if drop else found_pos + len(token) - consumed_len = found_pos if lookahead else found_pos + len(token) - self.unget(data[consumed_len:]) - return data[:result_len] + prev_len = len(data) + + i = data.find(found_delim) + j = i + len(found_delim) + if not drop: + i = j + + ret, data = data[:i], data[j:] + self.unget(data) + if lookahead: + self.unget(found_delim) + + return ret def recvline(self, size: int=4096, timeout: Optional[Union[int, float]]=None, - drop: bool=True) -> bytes: + drop: bool=True, + lookahead: bool=False) -> bytes: """Receive a line of data Args: - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool) : Discard delimiter or not + size : The data size to receive at once + timeout : Timeout (in second) + drop : Discard trailing newlines or not + lookahead: Unget trailing newline to buffer or not Returns: bytes: Received data """ - line = self.recvuntil(b'\n', size, timeout) - if drop: - return line.rstrip() - return line + try: + line = self.recvuntil(b'\n', size, timeout, lookahead=lookahead) + except TimeoutError as err: + raise TimeoutError("Timeout (recvline)", err.args[1]) + + return line.rstrip() if drop else line def recvlineafter(self, delim: Union[str, bytes], size: int=4096, timeout: Optional[Union[int, float]]=None, - drop: bool=True) -> bytes: + drop: bool=True, + lookahead: bool=False) -> bytes: """Receive a line of data after receiving `delim` Args: - delim (bytes): The delimiter bytes - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool) : Discard delimiter or not + delim : The delimiter bytes + size : The data size to receive at once + timeout : Timeout (in second) + drop : Discard trailing newline or not + lookahead: Unget trailing newline to buffer or not Returns: bytes: Received data - """ - self.recvuntil(delim, size, timeout) - return self.recvline(size, timeout, drop) - # TODO: proper typing - @overload - def recvregex(self, regex: Union[str, bytes], size: int=4096, discard: Literal[True]=True, timeout: Optional[Union[int, float]]=None) -> bytes: ... + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + """ + try: + self.recvuntil(delim, size, timeout) + except TimeoutError as err: + # NOTE: We do not set received value here + raise TimeoutError("Timeout (recvlineafter)", b'') - @overload - def recvregex(self, regex: Union[str, bytes], size: int=4096, discard: Literal[False]=False, timeout: Optional[Union[int, float]]=None) -> Tuple[bytes, bytes]: ... + try: + return self.recvline(size, timeout, drop, lookahead) + except TimeoutError as err: + raise TimeoutError("Timeout (recvlineafter)", err.args[1]) def recvregex(self, - regex: Union[str, bytes], + regex: Union[str, bytes, re.Pattern], size: int=4096, - discard: bool=True, - timeout: Optional[Union[int, float]]=None) -> Union[bytes, Tuple[bytes, bytes]]: + timeout: Optional[Union[int, float]]=None) -> Union[bytes, Tuple[bytes, ...]]: """Receive until a pattern comes Receive data until a specified regex pattern matches. Args: - regex (bytes) : Regex - size (int) : Size to read at once - discard (bool): Discard received bytes or not - timeout (int) : Timeout (in second) + regex : Regular expression + size : Size to read at once + timeout: Timeout in second Returns: tuple: If the given regex has multiple patterns to find, it returns all matches. Otherwise, it returns the - match string. If discard is false, it also returns - all data received so far along with the matches. + matched string. + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - if not isinstance(regex, bytes): - regex = str2bytes(regex) + assert isinstance(regex, (str, bytes, re.Pattern)), \ + "`regex` must be either str, bytes, or re.Pattern" - p = re.compile(regex) - data = b'' + if isinstance(regex, str): + regex = re.compile(str2bytes(regex)) - self._settimeout(timeout) - r = None - while r is None: - data += self.recv(size, timeout=-1) - r = p.search(data) + data = b'' + match = None + while match is None: + try: + data += self.recv(size, timeout) + except TimeoutError as err: + raise TimeoutError("Timeout (recvregex)", data + err.args[1]) + match = regex.search(data) - pos = r.end() - self.unget(data[pos:]) + self.unget(data[match.end():]) - group = r.group() - groups = r.groups() - if groups: - if discard: - return groups - else: - return groups, data[:pos] + if match.groups(): + return match.groups() else: - if discard: - return group - else: - return group, data[:pos] + return match.group() - def recvscreen(self, delim: Optional[bytes]=b'\x1b[H', + def recvscreen(self, + delim: Optional[Union[str, bytes]]=b'\x1b[H', returns: Optional[type]=str, timeout: Optional[Union[int, float]]=None, timeout2: Optional[Union[int, float]]=1): @@ -283,104 +360,126 @@ def recvscreen(self, delim: Optional[bytes]=b'\x1b[H', Receive a screen drawn by ncurses Args: - delim (bytes) : Refresh sequence - returns (type): Return value as string or list - timeout (int) : Timeout to receive the first delimiter - timeout2 (int): Timeout to receive the second delimiter + delim : Refresh sequence + returns : Return value as string or list + timeout : Timeout to receive the first delimiter + timeout2: Timeout to receive the second delimiter Returns: str: Rectangle string drawing the screen + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - self.recvuntil(delim, timeout=timeout) + assert returns in [list, str, bytes], \ + "`returns` must be either list, str, or bytes" + try: - buf = self.recvuntil(delim, drop=True, lookahead=True, - timeout=timeout2) + self.recvuntil(delim, timeout=timeout) + except TimeoutError as err: + # NOTE: We do not set received value here + raise TimeoutError("Timeout (recvscreen)", b'') + + try: + buf = self.recvuntil(delim, drop=True, lookahead=True, timeout=timeout2) except TimeoutError as err: buf = err.args[1] - screen = draw_ansi(buf) - if returns == list: - return screen - elif returns == str: + screen = draw_ansi(buf) + if returns == str: return '\n'.join(map(lambda row: ''.join(row), screen)) elif returns == bytes: return b'\n'.join(map(lambda row: bytes(row), screen)) else: - raise TypeError("`returns` must be either list, str, or bytes") + return screen - @abstractmethod - def _send(self, data: bytes): - pass + def send(self, data: Union[str, bytes]) -> int: + """Send raw data - def send(self, data: bytes): - self._send(data) - if self.debug: - logger.info(f"Sent {hex(len(data))} ({len(data)}) bytes:") - hexdump(data, prefix=Color.YELLOW, postfix=Color.END) + Send as much data as possible. - @abstractmethod - def _socket(self) -> Optional[Any]: - pass + Args: + data: Data to send - def sendline(self, data: Union[str, bytes], timeout: Optional[Union[int, float]]=None): - """Send a line + Returns: + int: Length of sent data - Send a line of data. + Note: + It is NOT ensured that all data is sent. + Use `sendonce` to make sure the whole data is sent. - Args: - data (bytes) : Data to send - timeout (int): Timeout (in second) + Examples: + ``` + tube.send("Hello") + tube.send(b"\xde\xad\xbe\xef") + ``` """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): - data = str(data).encode() + assert isinstance(data, (str, bytes)), "`data` must be either str or bytes" - self.send(data + b'\n') + return self._send_impl(str2bytes(data)) - def sendafter(self, delim: Union[str, bytes], data: Union[str, bytes, int], timeout: Optional[Union[int, float]]=None): - """Send raw data after a delimiter + def sendline(self, + data: Union[int, float, str, bytes], + timeout: Optional[Union[int, float]]=None): + """Send a line - Send raw data after `delim` is received. + Send a line of data. Args: - delim (bytes): The delimiter data (bytes) : Data to send timeout (int): Timeout (in second) - - Returns: - bytes: Received bytes before `delim` comes. """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): - data = str(data).encode() + assert isinstance(data, (int, float, str, bytes)), \ + "`data` must be int, float, str, or bytes" - recv_data = self.recvuntil(delim, timeout=timeout) - self.send(data) + if isinstance(data, (int, float)): + data = str(data).encode() + else: + data = str2bytes(data) - return recv_data + self.send(data + b'\n') - def sendlineafter(self, delim: Union[str, bytes], data: Union[str, bytes, int], timeout: Optional[Union[int, float]]=None) -> bytes: + def sendafter(self, + delim: Union[str, bytes, List[Union[str, bytes]]], + data: Union[int, float, str, bytes], + size: int=4096, + timeout: Optional[Union[int, float]]=None, + drop: bool=False, + lookahead: bool=False) -> bytes: """Send raw data after a delimiter - Send raw data with newline after `delim` is received. + Send raw data after `delim` is received. Args: - delim (bytes): The delimiter - data (bytes) : Data to send - timeout (int): Timeout (in second) + delim : The delimiter + data : Data to send + size : Data size to receive at once + timeout : Timeout in second + drop : Discard delimiter or not + lookahead: Unget delimiter to buffer or not Returns: bytes: Received bytes before `delim` comes. + + Examples: + ``` + tube.sendafter("> ", p32(len(data)) + data) + tube.sendafter("command: ", 1) # b"1" is sent + ``` """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): + assert isinstance(data, (int, float, str, bytes)), \ + "`data` must be int, float, str, or bytes" + + if isinstance(data, (int, float)): data = str(data).encode() + else: + data = str2bytes(data) - recv_data = self.recvuntil(delim, timeout=timeout) - self.sendline(data, timeout=timeout) + recv_data = self.recvuntil(delim, size, timeout, drop, lookahead) + self.send(data) return recv_data @@ -390,7 +489,7 @@ def sendctrl(self, name: str): Send control key given its name Args: - name (str): Name of the control key to send + name: Name of the control key to send """ if name.lower() in ['w', 'up']: self.send(b'\x1bOA') @@ -409,44 +508,77 @@ def sendctrl(self, name: str): else: raise ValueError(f"Invalid control key name: {name}") - def sh(self, timeout: Optional[Union[int, float]]=None): + def sh(self, + timeout: Optional[Union[int, float]]=None, + prompt: str="[ptrlib]$ ", + raw: bool=False): """Alias for interactive + + Args: + timeout: Timeout in second + prompt : Prompt string to show on input """ - self.interactive(timeout) + self.interactive(timeout, prompt, raw) - def interactive(self, timeout: Optional[Union[int, float]]=None): + def interactive(self, + timeout: Union[int, float]=1, + prompt: str="[ptrlib]$ ", + raw: bool=False): """Interactive mode + + Args: + timeout: Timeout in second + prompt : Prompt string to show on input """ - def thread_recv(): - prev_leftover = None + prompt = f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}" + + def pretty_print_hex(c: str): + sys.stdout.write(f'{Color.RED}\\x{ord(c):02x}{Color.END}') + + def pretty_print(data: bytes, prev: bytes=b''): + """Print data in a human-friendly way + """ + leftover = b'' + + if raw: + sys.stdout.write(bytes2str(data)) + + else: + utf8str, leftover, marker = bytes2utf8(data) + if len(utf8str) == 0 and prev == leftover: + utf8str = f'{Color.RED}{bytes2hex(leftover)}{Color.END}' + leftover = b'' + + for c, t in zip(utf8str, marker): + if t: + if 0x7f <= ord(c) < 0x100: + pretty_print_hex(c) + elif ord(c) not in [0x09, 0x0a, 0x0d] and \ + ord(c) < 0x20: + pretty_print_hex(c) + else: + sys.stdout.write(c) + else: + pretty_print_hex(c) + + sys.stdout.flush() + return leftover + + def thread_recv(flag: threading.Event): + """Receive data from tube and print to stdout + """ + leftover = b'' while not flag.isSet(): try: - data = self.recv(size=4096, timeout=0.1) - if data is not None: - utf8str, leftover, marker = bytes2utf8(data) - if len(utf8str) == 0 and prev_leftover == leftover: - # Print raw hex string with color - # if the data is invalid as UTF-8 - utf8str = '{red}{hexstr}{end}'.format( - red=Color.RED, - hexstr=bytes2hex(leftover), - end=Color.END - ) - leftover = None - - for c, t in zip(utf8str, marker): - if t == True: - sys.stdout.write(c) - else: - sys.stdout.write('{red}{hexstr}{end}'.format( - red=Color.RED, - hexstr=str2hex(c), - end=Color.END - )) - sys.stdout.flush() - prev_leftover = leftover - if leftover is not None: - self.unget(leftover) + sys.stdout.write(prompt) + sys.stdout.flush() + data = self.recv(timeout=timeout) + leftover = pretty_print(data, leftover) + + if not self.is_alive(): + logger.error(f"Connection closed by {str(self)}") + flag.set() + except TimeoutError: pass except EOFError: @@ -455,51 +587,155 @@ def thread_recv(): except ConnectionAbortedError: logger.error("Receiver EOF") break - time.sleep(0.1) - - flag = threading.Event() - th = threading.Thread(target=thread_recv) - th.setDaemon(True) - th.start() - try: + def thread_send(flag: threading.Event): + """Read user input and send it to tube + """ + #sys.stdout.write(f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}") + #sys.stdout.flush() while not flag.isSet(): - data = input("{bold}{blue}[ptrlib]${end} ".format( - bold=Color.BOLD, blue=Color.BLUE, end=Color.END - )) - if self._socket() is None: - logger.error("Connection already closed") - break - if data is None: + (ready, _, _) = select.select([sys.stdin], [], [], 0.1) + if not ready: continue + + try: + self.send(sys.stdin.readline()) + except (ConnectionResetError, ConnectionAbortedError, OSError): flag.set() - else: - try: - self.sendline(data) - except ConnectionAbortedError: - logger.error("Sender EOF") - break - time.sleep(0.1) + + flag = threading.Event() + th_recv = threading.Thread(target=thread_recv, args=(flag,)) + th_send = threading.Thread(target=thread_send, args=(flag,)) + th_recv.start() + th_send.start() + try: + th_recv.join() + th_send.join() except KeyboardInterrupt: + logger.warning("Intterupted by user") + sys.stdin.close() flag.set() - while th.is_alive(): - th.join(timeout = 0.1) - time.sleep(0.1) + def close(self): + """Close this connection + + Note: + This method can only be called once. + """ + self._close_impl() + self._is_closed = True + + def unget(self, data: Union[str, bytes]): + """Unshift data to buffer + + Args: + data: Data to revert + + Examples: + ``` + leak = tube.recvline().rstrip(b"> ") + tube.unget("> ") + # ... + tube.sendlineafter("> ", "1") + ``` + """ + assert isinstance(data, (str, bytes)), "`data` must be either str or bytes" + + self._buffer = str2bytes(data) + self._buffer + + def is_alive(self) -> bool: + """Check if connection is not closed + + Returns: + bool: False if connection is closed, otherwise True + + Examples: + ``` + while tube.is_alive(): + print(tube.recv()) + ``` + """ + return self._is_alive_impl() + + def shutdown(self, target: Literal['send', 'recv']): + """Kill one connection + + Args: + target (str): Connection to close (`send` or `recv`) + + Examples: + The following code shuts down input of remote. + ``` + tube.shutdown("send") + data = tube.recv() # OK + tube.send(b"data") # NG + ``` + + The following code shuts down output of remote. + ``` + tube.shutdown("recv") + tube.send(b"data") # OK + data = tube.recv() # NG + ``` + """ + return self._shutdown_impl(target) def __enter__(self): return self - def __exit__(self, e_type, e_value, traceback): - self.close() + def __exit__(self, _e_type, _e_value, _traceback): + if not self._is_closed: + self.close() - @abstractmethod - def is_alive(self) -> bool: + def __str__(self) -> str: + return "" + + # + # Abstract methods + # + @abc.abstractmethod + @not_closed + def _recv_impl(self, size: int) -> bytes: + """Abstract method for `recv` + + Receives at most `size` bytes from tube. + This method must be a blocking method. + """ pass - @abstractmethod - def close(self): + @abc.abstractmethod + @not_closed + def _send_impl(self, data: bytes) -> int: + """Abstract method for `send` + + Sends tube as much data as possible. + + Args: + data: Data to send + """ pass - @abstractmethod - def shutdown(self, target: Literal['send', 'recv']): + @abc.abstractmethod + @not_closed + def _close_impl(self): + """Abstract method for `close` + + Close the connection. + This method is ensured to be called only once. + """ + pass + + @abc.abstractmethod + @not_closed + def _is_alive_impl(self) -> bool: + """Abstract method for `is_alive` + + This method must return True iff the connection is alive. + """ + pass + + @abc.abstractmethod + @not_closed + def _shutdown_impl(self, target: Literal['send', 'recv']): + """Kill one connection + """ pass diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index d2f6561..e69de29 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -1,250 +0,0 @@ -# coding: utf-8 -from logging import getLogger -from typing import List, Mapping -from ptrlib.binary.encoding import * -from .tube import * -import ctypes -import os -import time - -_is_windows = os.name == 'nt' -if _is_windows: - import win32api - import win32con - import win32file - import win32pipe - import win32process - import win32security - -logger = getLogger(__name__) - - -class WinPipe(object): - def __init__(self, inherit_handle: bool=True): - """Create a pipe for Windows - - Create a new pipe - - Args: - inherit_handle (bool): Whether the child can inherit this handle - - Returns: - WinPipe: ``WinPipe`` instance. - """ - attr = win32security.SECURITY_ATTRIBUTES() - attr.bInheritHandle = inherit_handle - self.rp, self.wp = win32pipe.CreatePipe(attr, 0) - - @property - def handle0(self) -> int: - return self.get_handle('recv') - @property - def handle1(self) -> int: - return self.get_handle('send') - - def get_handle(self, name: Literal['recv', 'send']='recv') -> int: - """Get endpoint of this pipe - - Args: - name (str): Handle to get (`recv` or `send`) - """ - if name in ['read', 'recv', 'stdin']: - return self.rp - - elif name in ['write', 'send', 'stdout', 'stderr']: - return self.wp - - else: - logger.error("You must specify `send` or `recv` as target.") - - @property - def size(self) -> int: - """Get the number of bytes available to read on this pipe""" - # (lpBytesRead, lpTotalBytesAvail, lpBytesLeftThisMessage) - return win32pipe.PeekNamedPipe(self.handle0, 0)[1] - - def _recv(self, size: int=4096): - if size <= 0: - logger.error("`size` must be larger than 0") - return b'' - - buf = ctypes.create_string_buffer(size) - win32file.ReadFile(self.handle0, buf) - - return buf.raw - - def recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None): - """Receive raw data - - Receive raw data of maximum `size` bytes length through the pipe. - - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) - - Returns: - bytes: The received data - """ - start = time.time() - # Wait until data arrives - while self.size == 0: - # Check timeout - if timeout is not None and time.time() - start > timeout: - raise TimeoutError("Receive timeout") - time.sleep(0.01) - - return self._recv(min(self.size, size)) - - def send(self, data: bytes): - """Send raw data - - Send raw data through the socket - - Args: - data (bytes) : Data to send - timeout (int): Timeout (in second) - """ - win32file.WriteFile(self.handle1, data) - - def close(self): - """Cleanly close this pipe""" - win32api.CloseHandle(self.rp) - win32api.CloseHandle(self.wp) - - def __del__(self): - self.close() - -class WinProcess(Tube): - def __init__(self, args: Union[List[Union[str, bytes]], str], env: Optional[Mapping[str, str]]=None, cwd: Optional[str]=None, flags: int=0, timeout: Optional[Union[int, float]]=None): - """Create a process - - Create a new process and make a pipe. - - Args: - args (list): The arguments to pass - env (list) : The environment variables - - Returns: - Process: ``Process`` instance. - """ - assert _is_windows - super().__init__() - - if isinstance(args, list): - for i, arg in enumerate(args): - if isinstance(arg, bytes): - args[i] = bytes2str(arg) - self.args = ' '.join(args) - self.filepath = args[0] - - # Check if arguments are safe for Windows - for arg in args: - if '"' not in arg: continue - if arg[0] == '"' and arg[-1] == '"': continue - logger.error("You have to escape the arguments by yourself.") - logger.error("Be noted what you are executing is") - logger.error("> " + self.args) - - else: - self.args = args - - # Create pipe - self.stdin = WinPipe() - self.stdout = WinPipe() - self.default_timeout = timeout - self.timeout = timeout - self.proc = None - - # Create process - info = win32process.STARTUPINFO() - info.dwFlags = win32con.STARTF_USESTDHANDLES - info.hStdInput = self.stdin.handle0 - info.hStdOutput = self.stdout.handle1 - info.hStdError = self.stdout.handle1 - # (hProcess, hThread, dwProcessId, dwThreadId) - self.proc, _, self.pid, _ = win32process.CreateProcess( - None, self.args, # lpApplicationName, lpCommandLine - None, None, # lpProcessAttributes, lpThreadAttributes - True, flags, # bInheritHandles, dwCreationFlags - env, cwd, # lpEnvironment, lpCurrentDirectory - info # lpStartupInfo - ) - - logger.info("Successfully created new process (PID={})".format(self.pid)) - - def _settimeout(self, timeout: Optional[Union[int, float]]): - """Set timeout value""" - if timeout is None: - self.timeout = self.default_timeout - elif timeout > 0: - self.timeout = timeout - - def _socket(self): - return self.proc - - def _recv(self, size: int, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data - - Receive raw data of maximum `size` bytes length through the pipe. - - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) - - Returns: - bytes: The received data - """ - self._settimeout(timeout) - if size <= 0: - logger.error("`size` must be larger than 0") - return b'' - - buf = self.stdout.recv(size, self.timeout) - return buf - - def is_alive(self) -> bool: - """Check if process is alive - - Returns: - bool: True if process is alive, otherwise False - """ - if self.proc is None: - return False - else: - status = win32process.GetExitCodeProcess(self.proc) - return status == win32con.STILL_ACTIVE - - def close(self): - if self.proc: - win32api.TerminateProcess(self.proc, 0) - win32api.CloseHandle(self.proc) - self.proc = None - logger.info("Process killed (PID={0})".format(self.pid)) - - def _send(self, data: bytes): - """Send raw data - - Send raw data through the socket - - Args: - data (bytes) : Data to send - """ - self.stdin.send(data) - - def shutdown(self, target: Literal['send', 'recv']): - """Close a connection - - Args: - target (str): Pipe to close (`recv` or `send`) - """ - if target in ['write', 'send', 'stdin']: - self.stdin.close() - - elif target in ['read', 'recv', 'stdout', 'stderr']: - self.stdout.close() - - else: - logger.error("You must specify `send` or `recv` as target.") - - def __del__(self): - self.close() From cb75233c6f841460327b8dd485af17eb1a222081 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Tue, 23 Apr 2024 15:08:08 +0900 Subject: [PATCH 02/12] Support multi-threading in socket --- ptrlib/connection/proc.py | 211 ++++++++++++++++++++++++++++++++++++++ ptrlib/connection/sock.py | 46 +++++---- ptrlib/connection/tube.py | 160 ++++++++++++++++++++++------- 3 files changed, 362 insertions(+), 55 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index e69de29..3aeb171 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -0,0 +1,211 @@ +import os +import select +import subprocess +from logging import getLogger +from typing import List, Literal, Mapping, Optional, Union +from ptrlib.arch.linux.sig import signal_name +from ptrlib.binary.encoding import bytes2str, str2bytes +from .tube import Tube, tube_is_open + + +_is_windows = os.name == 'nt' +if not _is_windows: + import fcntl + import pty + import tty + +logger = getLogger(__name__) + +class UnixProcess(Tube): + # + # Constructor + # + def __init__(self, + args: Union[bytes, str, List[Union[bytes, str]]], + env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, + cwd: Optional[Union[bytes, str]]=None, + shell: Optional[bool]=None, + raw: bool=False, + stdin : Optional[int]=None, + stdout: Optional[int]=None, + stderr: Optional[int]=None, + **kwargs): + """Create a UNIX process + + Create a UNIX process and make a pipe. + + Args: + args : The arguments to pass + env : The environment variables + cwd : Working directory + shell : If true, `args` is a shell command + raw : Disable pty if this parameter is true + stdin : File descriptor of standard input + stdout : File descriptor of standard output + stderr : File descriptor of standard error + + Returns: + Process: ``Process`` instance + + Examples: + ``` + p = Process("/bin/ls", cwd="/tmp") + p = Process(["wget", "www.example.com"], + stderr=subprocess.DEVNULL) + p = Process("cat /proc/self/maps", env={"LD_PRELOAD": "a.so"}) + ``` + """ + assert not _is_windows, "UnixProcess cannot work on Windows" + assert isinstance(args, (str, bytes, list)), \ + "`args` must be either str, bytes, or list" + assert env is None or isinstance(env, dict), \ + "`env` must be a dictionary" + assert cwd is None or isinstance(cwd, (str, bytes)), \ + "`cwd` must be either str or bytes" + + super().__init__(**kwargs) + + # Guess shell mode based on args + if shell is None: + if isinstance(args, (str, bytes)): + args = [bytes2str(args)] + if ' ' in args[0]: + shell = True + logger.info("Detected whitespace in arguments: " \ + "`shell=True` enabled") + else: + shell = False + else: + shell = False + + else: + if isinstance(args, (str, bytes)): + args = [bytes2str(args)] + else: + args = list(map(bytes2str, args)) + + # Prepare stdio + if raw: + pass + else: + master, self._slave = pty.openpty() + tty.setraw(master) + tty.setraw(self._slave) + + if stdin is None: stdin = subprocess.PIPE + if stdout is None: stdout = subprocess.PIPE + if stderr is None: stderr = subprocess.STDOUT + + # Open process + assert isinstance(shell, bool), "`shell` must be boolean" + try: + self._proc = subprocess.Popen( + args, cwd=cwd, env=env, + shell=shell, + stdin=stdin, + stdout=stdout, + stderr=stderr, + ) + except FileNotFoundError as err: + logger.error(f"Could not execute {args[0]}") + raise err from None + + self._filepath = args[0] + + self._returncode = None + self._current_timeout = self._default_timeout + + # + # Properties + # + @property + def returncode(self) -> Optional[int]: + return self._returncode + + # + # Implementation of Tube methods + # + def _settimeout_impl(self, timeout: Union[int, float]): + self._current_timeout = timeout + + def _recv_impl(self, size: int) -> bytes: + """Receive raw data + + Receive raw data of maximum `size` bytes through the pipe. + + Args: + size: Data size to receive + + Returns: + bytes: The received data + """ + ready, [], [] = select.select( + [self._proc.stdout], [], [], self._current_timeout + ) + if len(ready) == 0: + raise TimeoutError("Timeout (_recv_impl)", b'') from None + + try: + data = self._proc.stdout.read(size) + except subprocess.TimeoutExpired: + raise TimeoutError("Timeout (_recv_impl)", b'') from None + + return data + + def _send_impl(self, data: bytes) -> int: + return 0 + + def _shutdown_recv_impl(self): + """Close stdin + """ + self._proc.stdout.close() + + def _shutdown_send_impl(self): + """Close stdout + """ + self._proc.stdin.close() + + def _close_impl(self): + """Close process + """ + self._proc.stdin.close() + self._proc.stdout.close() + if self._is_alive_impl(): + self._proc.kill() + self._proc.wait() + logger.info(f"{str(self)} killed") + else: + logger.info(f"{str(self)} has already exited") + + def _is_alive_impl(self) -> bool: + """Check if the process is alive""" + return self.poll() is None + + def __str__(self) -> str: + return f"'{self._filepath}' (PID={self._proc.pid})" + + + # + # Custom method + # + @tube_is_open + def poll(self) -> Optional[int]: + """Check if the process has exited + """ + if self._proc.poll() is None: + return None + + if self._returncode is None: + # First time to detect process exit + self._returncode = self._proc.returncode + name = signal_name(-self._returncode, detail=True) + if name: + name = '--> ' + name + logger.error(f"{str(self)} stopped with exit code " \ + f"{self._returncode} {name}") + + return self._returncode + + +Process = WinProcess if _is_windows else UnixProcess +process = Process # alias for the Process diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index 823cf5b..265080f 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -4,7 +4,9 @@ from logging import getLogger from typing import Literal, Optional, Union from ptrlib.binary.encoding import * -from .tube import Tube +from .tube import Tube, tube_is_open + +logger = getLogger(__name__) class Socket(Tube): @@ -30,8 +32,6 @@ def __init__(self, Returns: Socket: ``Socket`` instance. """ - super().__init__() - # Interpret host name and port number host = bytes2str(host) if port is None: @@ -74,6 +74,8 @@ def __init__(self, logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None + super().__init__(**kwargs) + # # Implementation of Tube methods # @@ -103,9 +105,20 @@ def _recv_impl(self, size: int) -> bytes: TimeoutError: Timeout exceeded OSError: System error """ + # NOTE: We cannot rely on the blocking behavior of `recv` + # because the socket might be non-blocking mode + # due to `_is_alive_impl` on multi-thread environment. + select.select([self._sock], [], []) + try: data = self._sock.recv(size) + except BlockingIOError: + # NOTE: This exception can occur if this method is called + # while `_is_alive_impl` is running in multi-thread. + # We make `_recv_impl` fail in this case. + return b'' + except socket.timeout: raise TimeoutError("Timeout (_recv_impl)", b'') from None @@ -135,7 +148,7 @@ def _send_impl(self, data: bytes) -> int: data = str2bytes(data) try: - self._sock.send(data) + return self._sock.send(data) except BrokenPipeError as e: logger.error("Broken pipe") @@ -157,6 +170,7 @@ def _close_impl(self): """Close socket """ self._sock.close() + logger.info(f"Connection to {str(self)} closed") def _is_alive_impl(self) -> bool: """Check if socket is alive @@ -181,22 +195,15 @@ def _is_alive_impl(self) -> bool: return ret - def _shutdown_impl(self, target: Literal['send', 'recv']): - """Kill one connection - - Close send/recv socket. - - Args: - target (str): Connection to close (`send` or `recv`) + def _shutdown_recv_impl(self): + """Close read """ - if target in ['write', 'send', 'stdin']: - self._sock.shutdown(socket.SHUT_WR) - - elif target in ['read', 'recv', 'stdout', 'stderr']: - self._sock.shutdown(socket.SHUT_RD) + self._sock.shutdown(socket.SHUT_RD) - else: - raise ValueError("`target` must either 'send' or 'recv'") + def _shutdown_send_impl(self): + """Close write + """ + self._sock.shutdown(socket.SHUT_WR) def __str__(self) -> str: return f"{self._host}:{self._port}" @@ -205,6 +212,7 @@ def __str__(self) -> str: # # Custom methods # + @tube_is_open def set_keepalive(self, keep_idle: Optional[Union[int, float]]=None, keep_interval: Optional[Union[int, float]]=None, @@ -229,4 +237,4 @@ def set_keepalive(self, self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keep_count) -remote = Socket +remote = Socket # alias diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index ce28e41..85c9193 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -5,11 +5,51 @@ import threading from logging import getLogger from typing import List, Literal, Optional, Tuple, Union -from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8 +from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump from ptrlib.console.color import Color logger = getLogger(__name__) +def tube_is_open(method): + """Ensure that connection is not *explicitly* closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_closed: + raise BrokenPipeError("Connection has already been closed by `close`") + return method(self, *args, **kwargs) + return decorator + +def tube_is_alive(method): + """Ensure that connection is not *implicitly* closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if not self.is_alive(): + raise BrokenPipeError("Connection has already been closed by {str(args[0])}") + return method(self, *args, **kwargs) + return decorator + +def tube_is_send_open(method): + """Ensure that sender connection is not explicitly closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_send_closed: + raise BrokenPipeError("Connection has already been closed by `shutdown`") + return method(self, *args, **kwargs) + return decorator + +def tube_is_recv_open(method): + """Ensure that receiver connection is not explicitly closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_recv_closed: + raise BrokenPipeError("Connection has already been closed by `shutdown`") + return method(self, *args, **kwargs) + return decorator + class Tube(metaclass=abc.ABCMeta): """Abstract class for streaming data @@ -21,41 +61,54 @@ class Tube(metaclass=abc.ABCMeta): - "_send_impl" - "_close_impl" - "_is_alive_impl - - "_shutdown_impl" + - "_shutdown_recv_impl" + - "_shutdown_send_impl" """ - # - # Decorator - # - def not_closed(method): - """Ensure that socket is not *explicitly* closed - """ - def decorator(*args, **kwargs): - assert isinstance(args[0], Tube), "Invalid usage of decorator" - if args[0]._is_closed: - raise BrokenPipeError("Socket has already been closed") - return method(*args, **kwargs) - return decorator + def __new__(cls, *args, **kwargs): + cls._settimeout_impl = tube_is_open(cls._settimeout_impl) + cls._recv_impl = tube_is_recv_open(tube_is_open(cls._recv_impl)) + cls._send_impl = tube_is_send_open(tube_is_open(cls._send_impl)) + cls._close_impl = tube_is_open(cls._close_impl) + cls._is_alive_impl = tube_is_open(cls._is_alive_impl) + cls._shutdown_recv_impl = tube_is_recv_open(cls._shutdown_recv_impl) + cls._shutdown_send_impl = tube_is_send_open(cls._shutdown_send_impl) + return super().__new__(cls) # # Constructor # def __init__(self, - timeout: Optional[Union[int, float]]=None): - """ + timeout: Optional[Union[int, float]]=None, + debug: bool=False): + """Base constructor + Args: timeout (float): Default timeout """ self._buffer = b'' + self._debug = debug self._is_closed = False + self._is_send_closed = False + self._is_recv_closed = False self._default_timeout = timeout self.settimeout() + # + # Properties + # + @property + def debug(self): + return self._debug + + @debug.setter + def debug(self, is_debug): + self._debug = bool(is_debug) + # # Methods # - @not_closed def settimeout(self, timeout: Optional[Union[int, float]]=None): """Set timeout @@ -122,10 +175,15 @@ def recv(self, return data if timeout is not None: - self.settimeout(timeout) + self.settimeout(0) try: - self._buffer += self._recv_impl(size - len(self._buffer)) + data = self._recv_impl(size - len(self._buffer)) + if self._debug and len(data) > 0: + logger.info(f"Received {hex(len(data))} ({len(data)}) bytes:") + hexdump(data, prefix=" " + Color.CYAN, postfix=Color.END) + + self._buffer += data except TimeoutError as err: data = self._buffer + err.args[1] @@ -419,7 +477,28 @@ def send(self, data: Union[str, bytes]) -> int: """ assert isinstance(data, (str, bytes)), "`data` must be either str or bytes" - return self._send_impl(str2bytes(data)) + size = self._send_impl(str2bytes(data)) + if self.debug: + logger.info(f"Sent {hex(size)} ({size}) bytes:") + hexdump(data[:size], prefix=Color.YELLOW, postfix=Color.END) + + return size + + def sendall(self, data: Union[str, bytes]): + """Send the whole data + + Send the whole data. + This method will never return until it finishes sending + the whole data, unlike `send`. + + Args: + data: Data to send + """ + to_send = len(data) + while to_send > 0: + sent = self.send(data) + data = data[sent:] + to_send -= sent def sendline(self, data: Union[int, float, str, bytes], @@ -509,26 +588,24 @@ def sendctrl(self, name: str): raise ValueError(f"Invalid control key name: {name}") def sh(self, - timeout: Optional[Union[int, float]]=None, prompt: str="[ptrlib]$ ", raw: bool=False): """Alias for interactive Args: - timeout: Timeout in second - prompt : Prompt string to show on input + prompt: Prompt string to show on input + raw : Escape non-printable characters or not """ - self.interactive(timeout, prompt, raw) + self.interactive(prompt, raw) def interactive(self, - timeout: Union[int, float]=1, prompt: str="[ptrlib]$ ", raw: bool=False): """Interactive mode Args: - timeout: Timeout in second - prompt : Prompt string to show on input + prompt: Prompt string to show on input + raw : Escape non-printable characters or not """ prompt = f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}" @@ -572,7 +649,7 @@ def thread_recv(flag: threading.Event): try: sys.stdout.write(prompt) sys.stdout.flush() - data = self.recv(timeout=timeout) + data = self.recv() leftover = pretty_print(data, leftover) if not self.is_alive(): @@ -602,6 +679,9 @@ def thread_send(flag: threading.Event): except (ConnectionResetError, ConnectionAbortedError, OSError): flag.set() + # Disable timeout + self.settimeout(0) + flag = threading.Event() th_recv = threading.Thread(target=thread_recv, args=(flag,)) th_send = threading.Thread(target=thread_send, args=(flag,)) @@ -677,7 +757,14 @@ def shutdown(self, target: Literal['send', 'recv']): data = tube.recv() # NG ``` """ - return self._shutdown_impl(target) + if target in ['write', 'send', 'stdin']: + self._shutdown_send_impl() + self._is_send_closed = True + elif target in ['read', 'recv', 'stdout', 'stderr']: + self._shutdown_recv_impl() + self._is_recv_closed = True + else: + raise ValueError("`target` must either 'send' or 'recv'") def __enter__(self): return self @@ -693,7 +780,6 @@ def __str__(self) -> str: # Abstract methods # @abc.abstractmethod - @not_closed def _recv_impl(self, size: int) -> bytes: """Abstract method for `recv` @@ -703,7 +789,6 @@ def _recv_impl(self, size: int) -> bytes: pass @abc.abstractmethod - @not_closed def _send_impl(self, data: bytes) -> int: """Abstract method for `send` @@ -715,7 +800,6 @@ def _send_impl(self, data: bytes) -> int: pass @abc.abstractmethod - @not_closed def _close_impl(self): """Abstract method for `close` @@ -725,7 +809,6 @@ def _close_impl(self): pass @abc.abstractmethod - @not_closed def _is_alive_impl(self) -> bool: """Abstract method for `is_alive` @@ -734,8 +817,13 @@ def _is_alive_impl(self) -> bool: pass @abc.abstractmethod - @not_closed - def _shutdown_impl(self, target: Literal['send', 'recv']): - """Kill one connection + def _shutdown_recv_impl(self): + """Kill receiver connection + """ + pass + + @abc.abstractmethod + def _shutdown_send_impl(self): + """Kill sender connection """ pass From a929ae83931bd2cfc8bb2dd98d31d717d43f4e48 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Tue, 23 Apr 2024 15:36:57 +0900 Subject: [PATCH 03/12] Fix timeout bug --- ptrlib/connection/proc.py | 32 ++++++++++++++++++++++++++++++-- ptrlib/connection/sock.py | 16 ++++++++++------ ptrlib/connection/tube.py | 2 +- tests/connection/test_sock.py | 2 +- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index 3aeb171..79d8fb5 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -86,11 +86,13 @@ def __init__(self, # Prepare stdio if raw: + # TODO pass else: master, self._slave = pty.openpty() tty.setraw(master) tty.setraw(self._slave) + stdout = self._slave if stdin is None: stdin = subprocess.PIPE if stdout is None: stdout = subprocess.PIPE @@ -115,6 +117,18 @@ def __init__(self, self._returncode = None self._current_timeout = self._default_timeout + # Duplicate master + if not raw and master is not None: + self._proc.stdout = os.fdopen(os.dup(master), 'r+b', 0) + os.close(master) + + # Set in non-blocking mode + fd = self._proc.stdout.fileno() + fl = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) + + logger.info(f"Successfully created new process {str(self)}") + # # Properties # @@ -140,7 +154,7 @@ def _recv_impl(self, size: int) -> bytes: bytes: The received data """ ready, [], [] = select.select( - [self._proc.stdout], [], [], self._current_timeout + [self._proc.stdout.fileno()], [], [], self._current_timeout ) if len(ready) == 0: raise TimeoutError("Timeout (_recv_impl)", b'') from None @@ -153,7 +167,21 @@ def _recv_impl(self, size: int) -> bytes: return data def _send_impl(self, data: bytes) -> int: - return 0 + """Send raw data + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + """ + try: + n_written = self._proc.stdin.write(data) + self._proc.stdin.flush() + return n_written + except IOError as err: + logger.error("Broken pipe: {str(self)}") + raise err from None def _shutdown_recv_impl(self): """Close stdin diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index 265080f..b8debcb 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -3,7 +3,7 @@ import socket from logging import getLogger from typing import Literal, Optional, Union -from ptrlib.binary.encoding import * +from ptrlib.binary.encoding import bytes2str from .tube import Tube, tube_is_open logger = getLogger(__name__) @@ -32,6 +32,8 @@ def __init__(self, Returns: Socket: ``Socket`` instance. """ + super().__init__(**kwargs) + # Interpret host name and port number host = bytes2str(host) if port is None: @@ -74,7 +76,7 @@ def __init__(self, logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None - super().__init__(**kwargs) + self._current_timeout = self._default_timeout # # Implementation of Tube methods @@ -86,7 +88,7 @@ def _settimeout_impl(self, Args: timeout: Timeout in second """ - self._sock.settimeout(timeout) + self._current_timeout = timeout def _recv_impl(self, size: int) -> bytes: """Receive raw data @@ -108,7 +110,11 @@ def _recv_impl(self, size: int) -> bytes: # NOTE: We cannot rely on the blocking behavior of `recv` # because the socket might be non-blocking mode # due to `_is_alive_impl` on multi-thread environment. - select.select([self._sock], [], []) + ready, [], [] = select.select( + [self._sock], [], [], self._current_timeout + ) + if len(ready) == 0: + raise TimeoutError("Timeout (_recv_impl)", b'') from None try: data = self._sock.recv(size) @@ -145,8 +151,6 @@ def _send_impl(self, data: bytes) -> int: TimeoutError: Timeout exceeded OSError: System error """ - data = str2bytes(data) - try: return self._sock.send(data) diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index 85c9193..80df5c0 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -175,7 +175,7 @@ def recv(self, return data if timeout is not None: - self.settimeout(0) + self.settimeout(timeout) try: data = self._recv_impl(size - len(self._buffer)) diff --git a/tests/connection/test_sock.py b/tests/connection/test_sock.py index 73cd431..23c036f 100644 --- a/tests/connection/test_sock.py +++ b/tests/connection/test_sock.py @@ -30,7 +30,7 @@ def test_timeout(self): sock.sendline(b'GET / HTTP/1.1\r') sock.send(b'Host: www.example.com\r\n\r\n') try: - sock.recvuntil("*** never expected ***", timeout=1) + sock.recvuntil("*** never expected ***", timeout=2) result = False except TimeoutError as err: self.assertEqual(b"200 OK" in err.args[1], True) From 70b130aa6bc1511386303d1a377bf2cfccbadbee Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Tue, 23 Apr 2024 16:29:44 +0900 Subject: [PATCH 04/12] Implement process class --- ptrlib/connection/proc.py | 47 ++++++++++++++++++++++++----------- ptrlib/connection/sock.py | 11 +++++--- ptrlib/connection/tube.py | 42 +++++++++++++++++++++++-------- tests/connection/test_proc.py | 6 ++--- tests/connection/test_sock.py | 22 ++++++++-------- 5 files changed, 86 insertions(+), 42 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index 79d8fb5..cb33c9f 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -85,10 +85,8 @@ def __init__(self, args = list(map(bytes2str, args)) # Prepare stdio - if raw: - # TODO - pass - else: + master = self._slave = None + if not raw: master, self._slave = pty.openpty() tty.setraw(master) tty.setraw(self._slave) @@ -118,7 +116,7 @@ def __init__(self, self._current_timeout = self._default_timeout # Duplicate master - if not raw and master is not None: + if master is not None: self._proc.stdout = os.fdopen(os.dup(master), 'r+b', 0) os.close(master) @@ -153,8 +151,13 @@ def _recv_impl(self, size: int) -> bytes: Returns: bytes: The received data """ + if self._current_timeout == 0: + timeout = None + else: + timeout = self._current_timeout + ready, [], [] = select.select( - [self._proc.stdout.fileno()], [], [], self._current_timeout + [self._proc.stdout.fileno()], [], [], timeout ) if len(ready) == 0: raise TimeoutError("Timeout (_recv_impl)", b'') from None @@ -196,14 +199,16 @@ def _shutdown_send_impl(self): def _close_impl(self): """Close process """ - self._proc.stdin.close() - self._proc.stdout.close() if self._is_alive_impl(): self._proc.kill() self._proc.wait() - logger.info(f"{str(self)} killed") - else: - logger.info(f"{str(self)} has already exited") + logger.info(f"{str(self)} killed by `close`") + + if self._slave is not None: # PTY mode + os.close(self._slave) + + self._proc.stdin.close() + self._proc.stdout.close() def _is_alive_impl(self) -> bool: """Check if the process is alive""" @@ -216,7 +221,6 @@ def __str__(self) -> str: # # Custom method # - @tube_is_open def poll(self) -> Optional[int]: """Check if the process has exited """ @@ -228,12 +232,25 @@ def poll(self) -> Optional[int]: self._returncode = self._proc.returncode name = signal_name(-self._returncode, detail=True) if name: - name = '--> ' + name - logger.error(f"{str(self)} stopped with exit code " \ - f"{self._returncode} {name}") + name = ' --> ' + name + + logger_func = logger.info if self._returncode == 0 else logger.error + logger_func(f"{str(self)} stopped with exit code " \ + f"{self._returncode}{name}") return self._returncode + @tube_is_open + def wait(self, timeout: Optional[Union[int, float]]=None) -> int: + """Wait until the process dies + + Wait until the process exits and get the status code. + + Returns: + code (int): Status code of the process + """ + return self._proc.wait(timeout) + Process = WinProcess if _is_windows else UnixProcess process = Process # alias for the Process diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index b8debcb..8ad0104 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -110,14 +110,19 @@ def _recv_impl(self, size: int) -> bytes: # NOTE: We cannot rely on the blocking behavior of `recv` # because the socket might be non-blocking mode # due to `_is_alive_impl` on multi-thread environment. - ready, [], [] = select.select( - [self._sock], [], [], self._current_timeout - ) + if self._current_timeout == 0: + timeout = None + else: + timeout = self._current_timeout + + ready, [], [] = select.select([self._sock], [], [], timeout) if len(ready) == 0: raise TimeoutError("Timeout (_recv_impl)", b'') from None try: data = self._sock.recv(size) + if len(data) == 0: + raise ConnectionResetError("Empty reply") from None except BlockingIOError: # NOTE: This exception can occur if this method is called diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index 80df5c0..ae9d03b 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -269,7 +269,7 @@ def recvuntil(self, for i, d in enumerate(delim): assert isinstance(d, (str, bytes)), \ f"`delim[{i}]` must be either str or bytes" - delim[i] = str2bytes(delim) + delim[i] = str2bytes(delim[i]) else: delim = [str2bytes(delim)] @@ -282,6 +282,9 @@ def recvuntil(self, data += self.recv(size, timeout) except TimeoutError as err: raise TimeoutError("Timeout (recvuntil)", data + err.args[1]) + except Exception as err: + err.args = (err.args[0], data) + raise err from None for d in delim: if d in data[max(0, prev_len-len(d)):]: @@ -549,16 +552,32 @@ def sendafter(self, tube.sendafter("command: ", 1) # b"1" is sent ``` """ - assert isinstance(data, (int, float, str, bytes)), \ - "`data` must be int, float, str, or bytes" + recv_data = self.recvuntil(delim, size, timeout, drop, lookahead) + self.send(data) - if isinstance(data, (int, float)): - data = str(data).encode() - else: - data = str2bytes(data) + return recv_data + + def sendlineafter(self, + delim: Union[str, bytes], + data: Union[str, bytes, int], + size: int=4096, + timeout: Optional[Union[int, float]]=None, + drop: bool=False, + lookahead: bool=False) -> bytes: + """Send raw data after a delimiter + + Send raw data with newline after `delim` is received. + Args: + delim (bytes): The delimiter + data (bytes) : Data to send + timeout (int): Timeout (in second) + + Returns: + bytes: Received bytes before `delim` comes. + """ recv_data = self.recvuntil(delim, size, timeout, drop, lookahead) - self.send(data) + self.sendline(data, timeout=timeout) return recv_data @@ -679,9 +698,6 @@ def thread_send(flag: threading.Event): except (ConnectionResetError, ConnectionAbortedError, OSError): flag.set() - # Disable timeout - self.settimeout(0) - flag = threading.Event() th_recv = threading.Thread(target=thread_recv, args=(flag,)) th_send = threading.Thread(target=thread_send, args=(flag,)) @@ -776,6 +792,10 @@ def __exit__(self, _e_type, _e_value, _traceback): def __str__(self) -> str: return "" + def __del__(self): + if not self._is_closed: + self.close() + # # Abstract methods # diff --git a/tests/connection/test_proc.py b/tests/connection/test_proc.py index c6d5f3f..be73241 100644 --- a/tests/connection/test_proc.py +++ b/tests/connection/test_proc.py @@ -26,7 +26,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') # sendline / recvline p.sendline(b"Message : " + msg) @@ -70,7 +70,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p.close() self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:.+ \(PID=\d+\) has already exited$') + self.assertEqual(cm.output[0], fr'INFO:{module_name}:{str(p)} stopped with exit code {p.poll()}') def test_timeout(self): module_name = inspect.getmodule(Process).__name__ @@ -78,7 +78,7 @@ def test_timeout(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], fr'INFO:{module_name}:Successfully created new process {str(p)}') data = os.urandom(16).hex() # recv diff --git a/tests/connection/test_sock.py b/tests/connection/test_sock.py index 23c036f..e5b35ea 100644 --- a/tests/connection/test_sock.py +++ b/tests/connection/test_sock.py @@ -29,18 +29,20 @@ def test_timeout(self): sock = Socket("www.example.com", 80) sock.sendline(b'GET / HTTP/1.1\r') sock.send(b'Host: www.example.com\r\n\r\n') - try: + + with self.assertRaises(TimeoutError) as cm: sock.recvuntil("*** never expected ***", timeout=2) - result = False - except TimeoutError as err: - self.assertEqual(b"200 OK" in err.args[1], True) - result = True - except: - result = False - finally: - sock.close() + self.assertEqual(b"200 OK" in cm.exception.args[1], True) - self.assertEqual(result, True) + def test_reset(self): + sock = Socket("www.example.com", 80) + sock.sendline(b'GET / HTTP/1.1\r') + sock.send(b'Host: www.example.com\r\n') + sock.send(b'Connection: close\r\n\r\n') + + with self.assertRaises(ConnectionResetError) as cm: + sock.recvuntil("*** never expected ***", timeout=2) + self.assertEqual(b"200 OK" in cm.exception.args[1], True) def test_tls(self): host = "www.example.com" From 3ab9806cb5ac8b0a2cba4ccd311fbfda00fb114d Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Wed, 24 Apr 2024 19:09:55 +0900 Subject: [PATCH 05/12] Implement WinProcess --- ptrlib/connection/proc.py | 7 +- ptrlib/connection/sock.py | 8 +- ptrlib/connection/tube.py | 15 +- ptrlib/connection/winproc.py | 283 +++++++++++++++++++++++++++++++++++ 4 files changed, 305 insertions(+), 8 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index cb33c9f..483092d 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -4,8 +4,9 @@ from logging import getLogger from typing import List, Literal, Mapping, Optional, Union from ptrlib.arch.linux.sig import signal_name -from ptrlib.binary.encoding import bytes2str, str2bytes +from ptrlib.binary.encoding import bytes2str from .tube import Tube, tube_is_open +from .winproc import WinProcess _is_windows = os.name == 'nt' @@ -63,6 +64,9 @@ def __init__(self, assert cwd is None or isinstance(cwd, (str, bytes)), \ "`cwd` must be either str or bytes" + # NOTE: We need to initialize _current_timeout before super constructor + # because it may call _settimeout_impl + self._current_timeout = 0 super().__init__(**kwargs) # Guess shell mode based on args @@ -113,7 +117,6 @@ def __init__(self, self._filepath = args[0] self._returncode = None - self._current_timeout = self._default_timeout # Duplicate master if master is not None: diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index 8ad0104..a3231aa 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -32,6 +32,12 @@ def __init__(self, Returns: Socket: ``Socket`` instance. """ + assert isinstance(host, (str, bytes)), \ + "`host` must be either str or bytes" + + # NOTE: We need to initialize _current_timeout before super constructor + # because it may call _settimeout_impl + self._current_timeout = 0 super().__init__(**kwargs) # Interpret host name and port number @@ -76,8 +82,6 @@ def __init__(self, logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None - self._current_timeout = self._default_timeout - # # Implementation of Tube methods # diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index ae9d03b..5513f58 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -1,6 +1,5 @@ import abc import re -import select import sys import threading from logging import getLogger @@ -690,9 +689,6 @@ def thread_send(flag: threading.Event): #sys.stdout.write(f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}") #sys.stdout.flush() while not flag.isSet(): - (ready, _, _) = select.select([sys.stdin], [], [], 0.1) - if not ready: continue - try: self.send(sys.stdin.readline()) except (ConnectionResetError, ConnectionAbortedError, OSError): @@ -799,6 +795,17 @@ def __del__(self): # # Abstract methods # + @abc.abstractmethod + def _settimeout_impl(self, timeout: Union[int, float]): + """Abstract method for `settimeout` + + Set timeout for receive and send. + + Args: + timeout: Timeout in second + """ + pass + @abc.abstractmethod def _recv_impl(self, size: int) -> bytes: """Abstract method for `recv` diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index e69de29..4c1bc39 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -0,0 +1,283 @@ +from logging import getLogger +from typing import List, Mapping, Optional, Union +import os +import subprocess +from ptrlib.binary.encoding import bytes2str +from .tube import Tube + +_is_windows = os.name == 'nt' +if _is_windows: + import pywintypes + import win32api + import win32con + import win32event + import win32file + import win32pipe + import win32process + import win32security + +logger = getLogger(__name__) + +class WinPipe(object): + def __init__(self, + read: Optional[bool]=False, + write: Optional[bool]=False, + size: Optional[int]=65536): + """Create a pipe for Windows + + Create a new pipe with overlapped I/O. + + Args: + read: True if read mode + write: True if write mode + size: Default buffer size for this pipe + timeout: Default timeout in second + """ + if read and write: + mode = win32pipe.PIPE_ACCESS_DUPLEX + self._access = win32con.GENERIC_READ | win32con.GENERIC_WRITE + elif write: + mode = win32pipe.PIPE_ACCESS_OUTBOUND + self._access = win32con.GENERIC_READ + else: + mode = win32pipe.PIPE_ACCESS_INBOUND + self._access = win32con.GENERIC_WRITE + + self._attr = win32security.SECURITY_ATTRIBUTES() + self._attr.bInheritHandle = True + + self._name = f"\\\\.\\pipe\\ptrlib.{os.getpid()}.{os.urandom(8).hex()}" + self._handle = win32pipe.CreateNamedPipe( + self._name, mode | win32file.FILE_FLAG_OVERLAPPED, + win32pipe.PIPE_TYPE_BYTE | win32pipe.PIPE_READMODE_BYTE | win32pipe.PIPE_WAIT, + 1, size, size, 0, self._attr + ) + assert self._handle != win32file.INVALID_HANDLE_VALUE, \ + "Could not create a pipe" + + @property + def name(self) -> str: + return self._name + + @property + def access(self) -> int: + return self._access + + @property + def attributes(self) -> pywintypes.SECURITY_ATTRIBUTES: + return self._attr + + @property + def handle(self) -> int: + return self._handle + + def close(self): + """Gracefully close this pipe + """ + win32api.CloseHandle(self._handle) + + def __del__(self): + self.close() + +class WinProcess(Tube): + # + # Constructor + # + def __init__(self, + args: Union[List[Union[str, bytes]], str], + env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, + cwd: Optional[Union[bytes, str]]=None, + flags: int = 0, + raw: bool=False, + stdin : Optional[WinPipe]=None, + stdout: Optional[WinPipe]=None, + stderr: Optional[WinPipe]=None, + **kwargs): + """Create a Windows process + + Create a Windows process and make a pipe. + + Args: + args : The arguments to pass + env : The environment variables + cwd : Working directory + flags : dwCreationFlags passed to CreateProcess + raw : Disable pty if this parameter is true + stdin : File descriptor of standard input + stdout : File descriptor of standard output + stderr : File descriptor of standard error + + Returns: + WinProcess: ``WinProcess`` instance + + Examples: + ``` + p = Process("cmd.exe", cwd="C:\\") + p = Process(["cmd", "dir"], + stderr=subprocess.DEVNULL) + p = Process("more C:\\test.txt", env={"X": "123"}) + ``` + """ + assert _is_windows, "WinProcess cannot work on Unix" + assert isinstance(args, (str, bytes, list)), \ + "`args` must be either str, bytes, or list" + assert env is None or isinstance(env, dict), \ + "`env` must be a dictionary" + assert cwd is None or isinstance(cwd, (str, bytes)), \ + "`cwd` must be either str or bytes" + + self._current_timeout = 0 + super().__init__(**kwargs) + + if isinstance(args, list): + for i, arg in enumerate(args): + if isinstance(arg, bytes): + args[i] = bytes2str(arg) + args = subprocess.list2cmdline(args) + + else: + args = bytes2str(args) + + self._filepath = args + + # Prepare stdio + if stdin is None: + self._stdin = WinPipe(write=True) + proc_stdin = win32file.CreateFile( + self._stdin.name, self._stdin.access, + 0, self._stdin.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + if stdout is None: + self._stdout = WinPipe(read=True) + proc_stdout = win32file.CreateFile( + self._stdout.name, self._stdout.access, + 0, self._stdout.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + if stderr is None: + self._stderr = self._stdout + proc_stderr = proc_stdout + else: + proc_stderr = win32file.CreateFile( + self._stderr.name, self._stderr.access, + 0, self._stderr.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + # Open process + info = win32process.STARTUPINFO() + info.dwFlags = win32con.STARTF_USESTDHANDLES + info.hStdInput = proc_stdin + info.hStdOutput = proc_stdout + info.hStdError = proc_stderr + self._proc, _, self._pid, _ = win32process.CreateProcess( + None, args, None, None, True, flags, env, cwd, info + ) + + win32file.CloseHandle(proc_stdin) + win32file.CloseHandle(proc_stdout) + if proc_stdout != proc_stderr: + win32file.CloseHandle(proc_stderr) + + # Wait until connection + win32pipe.ConnectNamedPipe(self._stdin.handle) + win32pipe.ConnectNamedPipe(self._stdout.handle) + win32pipe.ConnectNamedPipe(self._stderr.handle) + + logger.info(f"Successfully created new process {str(self)}") + + # + # Implementation of Tube + # + def _settimeout_impl(self, timeout: Union[int, float]): + """Set timeout + + Args: + timeout: Timeout in second (Maximum precision is millisecond) + """ + self._current_timeout = timeout + + def _recv_impl(self, size: int) -> bytes: + """Receive raw data + + Args: + size: Size to receive + + Returns: + bytes: Received data + """ + if self._current_timeout == 0: + # Without timeout + try: + _, data = win32file.ReadFile(self._stdout.handle, size) + return data + except Exception as err: + raise err from None + + else: + # With timeout + overlapped = pywintypes.OVERLAPPED() + overlapped.hEvent = win32event.CreateEvent(None, 0, 0, None) + try: + _, data = win32file.ReadFile(self._stdout.handle, size, overlapped) + state = win32event.WaitForSingleObject( + overlapped.hEvent, int(self._current_timeout * 1000) + ) + if state == win32event.WAIT_OBJECT_0: + result = win32file.GetOverlappedResult(self._stdout.handle, overlapped, True) + if isinstance(result, int): + # NOTE: GetOverlappedResult does not return data + # when overlapped ReadFile is successful. + # We need to use the result of this API because + # we cannot access the number of bytes read by ReadFile. + # See https://github.com/mhammond/pywin32/issues/430 + return data[:result] + else: + return result[1] + else: + raise TimeoutError("Timeout (_recv_impl)", b'') + finally: + win32file.CloseHandle(overlapped.hEvent) + + def _send_impl(self, data: bytes) -> int: + """Send raw data + + Args: + data: Data to send + + Returns: + int: The number of bytes written + """ + _, n = win32file.WriteFile(self._stdin.handle, data) + return n + + def _close_impl(self): + win32api.TerminateProcess(self._proc, 0) + win32api.CloseHandle(self._proc) + logger.info(f"Process killed {str(self)}") + + def _is_alive_impl(self) -> bool: + """Check if process is alive + + Returns: + bool: True if process is alive, otherwise False + """ + status = win32process.GetExitCodeProcess(self._proc) + return status == win32con.STILL_ACTIVE + + def _shutdown_recv_impl(self): + """Kill receiver connection + """ + self._stdin.close() + + def _shutdown_send_impl(self): + """Kill sender connection + """ + self._stdout.close() + + def __str__(self) -> str: + return f'{self._filepath} (PID={self._pid})' + \ No newline at end of file From e8d6e911ffb3b21e8fb60f22b35b465f6959987c Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Wed, 24 Apr 2024 19:12:50 +0900 Subject: [PATCH 06/12] Fix tests --- ptrlib/connection/proc.py | 4 +++- ptrlib/connection/sock.py | 2 +- ptrlib/connection/winproc.py | 18 ++++++++++++++++++ tests/connection/test_proc.py | 2 +- tests/pwn/test_fsb.py | 6 +++--- 5 files changed, 26 insertions(+), 6 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index cb33c9f..7dbbce2 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -2,7 +2,7 @@ import select import subprocess from logging import getLogger -from typing import List, Literal, Mapping, Optional, Union +from typing import List, Mapping, Optional, Union from ptrlib.arch.linux.sig import signal_name from ptrlib.binary.encoding import bytes2str, str2bytes from .tube import Tube, tube_is_open @@ -190,6 +190,7 @@ def _shutdown_recv_impl(self): """Close stdin """ self._proc.stdout.close() + self._proc.stderr.close() def _shutdown_send_impl(self): """Close stdout @@ -209,6 +210,7 @@ def _close_impl(self): self._proc.stdin.close() self._proc.stdout.close() + self._proc.stderr.close() def _is_alive_impl(self) -> bool: """Check if the process is alive""" diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index 8ad0104..25420f3 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -2,7 +2,7 @@ import select import socket from logging import getLogger -from typing import Literal, Optional, Union +from typing import Optional, Union from ptrlib.binary.encoding import bytes2str from .tube import Tube, tube_is_open diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index e69de29..299e5ca 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -0,0 +1,18 @@ +from logging import getLogger +from typing import List, Mapping +from ptrlib.binary.encoding import str2bytes +from .tube import Tube +import ctypes +import os +import time + +_is_windows = os.name == 'nt' +if _is_windows: + import win32api + import win32con + import win32file + import win32pipe + import win32process + import win32security + +logger = getLogger(__name__) diff --git a/tests/connection/test_proc.py b/tests/connection/test_proc.py index be73241..e4d017c 100644 --- a/tests/connection/test_proc.py +++ b/tests/connection/test_proc.py @@ -70,7 +70,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p.close() self.assertEqual(len(cm.output), 1) - self.assertEqual(cm.output[0], fr'INFO:{module_name}:{str(p)} stopped with exit code {p.poll()}') + self.assertEqual(cm.output[0], fr'INFO:{module_name}:{str(p)} stopped with exit code 0') def test_timeout(self): module_name = inspect.getmodule(Process).__name__ diff --git a/tests/pwn/test_fsb.py b/tests/pwn/test_fsb.py index a4fcf30..d38d505 100644 --- a/tests/pwn/test_fsb.py +++ b/tests/pwn/test_fsb.py @@ -23,7 +23,7 @@ def test_fsb32(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x86") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb( @@ -42,7 +42,7 @@ def test_fsb32(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x86") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb( @@ -65,7 +65,7 @@ def test_fsb64(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb( From 85464df0eb0fedfe2d7920e93b5c7a1b57f6daf8 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Wed, 24 Apr 2024 19:27:13 +0900 Subject: [PATCH 07/12] Fix resource leak --- ptrlib/connection/__init__.py | 5 ++--- ptrlib/connection/proc.py | 12 ++++++++---- ptrlib/connection/winproc.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ptrlib/connection/__init__.py b/ptrlib/connection/__init__.py index 8706d43..beb4a72 100644 --- a/ptrlib/connection/__init__.py +++ b/ptrlib/connection/__init__.py @@ -1,4 +1,3 @@ -from .proc import * -from .sock import * +from .proc import Process, process +from .sock import Socket, remote from .ssh import * -from .winproc import * diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index d244b0d..bd7cdf3 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -210,10 +210,14 @@ def _close_impl(self): if self._slave is not None: # PTY mode os.close(self._slave) - - self._proc.stdin.close() - self._proc.stdout.close() - self._proc.stderr.close() + self._slave = None + + if self._proc.stdin is not None: + self._proc.stdin.close() + if self._proc.stdout is not None: + self._proc.stdout.close() + if self._proc.stderr is not None: + self._proc.stderr.close() def _is_alive_impl(self) -> bool: """Check if the process is alive""" diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index a27c480..5a03c0a 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -64,7 +64,7 @@ def access(self) -> int: return self._access @property - def attributes(self) -> pywintypes.SECURITY_ATTRIBUTES: + def attributes(self) -> any: return self._attr @property From 48f4999fc182576759080d926ee925783d36e5f4 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Wed, 24 Apr 2024 19:42:11 +0900 Subject: [PATCH 08/12] Fix for testcase --- ptrlib/connection/proc.py | 4 ++++ ptrlib/connection/tube.py | 12 ++---------- ptrlib/connection/winproc.py | 23 ++++++++++++++++++++--- tests/connection/test_windows_proc.py | 4 ++-- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index bd7cdf3..53b20ae 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -136,6 +136,10 @@ def __init__(self, @property def returncode(self) -> Optional[int]: return self._returncode + + @property + def pid(self) -> int: + return self._proc.pid # # Implementation of Tube methods diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index 5513f58..5215b93 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -19,16 +19,6 @@ def decorator(self, *args, **kwargs): return method(self, *args, **kwargs) return decorator -def tube_is_alive(method): - """Ensure that connection is not *implicitly* closed - """ - def decorator(self, *args, **kwargs): - assert isinstance(self, Tube), "Invalid usage of decorator" - if not self.is_alive(): - raise BrokenPipeError("Connection has already been closed by {str(args[0])}") - return method(self, *args, **kwargs) - return decorator - def tube_is_send_open(method): """Ensure that sender connection is not explicitly closed """ @@ -746,6 +736,8 @@ def is_alive(self) -> bool: print(tube.recv()) ``` """ + if self._is_closed: + return False return self._is_alive_impl() def shutdown(self, target: Literal['send', 'recv']): diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index 5a03c0a..390d480 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -187,8 +187,21 @@ def __init__(self, win32pipe.ConnectNamedPipe(self._stdout.handle) win32pipe.ConnectNamedPipe(self._stderr.handle) + self._returncode = None + logger.info(f"Successfully created new process {str(self)}") + # + # Property + # + @property + def returncode(self) -> Optional[int]: + return self._returncode + + @property + def pid(self) -> int: + return self._pid + # # Implementation of Tube # @@ -266,17 +279,21 @@ def _is_alive_impl(self) -> bool: bool: True if process is alive, otherwise False """ status = win32process.GetExitCodeProcess(self._proc) - return status == win32con.STILL_ACTIVE + if status == win32con.STILL_ACTIVE: + return True + else: + self._returncode = status + return False def _shutdown_recv_impl(self): """Kill receiver connection """ - self._stdin.close() + self._stdout.close() def _shutdown_send_impl(self): """Kill sender connection """ - self._stdout.close() + self._stdin.close() def __str__(self) -> str: return f'{self._filepath} (PID={self._pid})' diff --git a/tests/connection/test_windows_proc.py b/tests/connection/test_windows_proc.py index 5282a3f..66ce000 100644 --- a/tests/connection/test_windows_proc.py +++ b/tests/connection/test_windows_proc.py @@ -27,7 +27,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.pe.exe") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') pid = p.pid # send / recv @@ -58,7 +58,7 @@ def test_timeout(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.pe.exe") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') with self.assertRaises(TimeoutError): p.recvuntil("*** never expected ***", timeout=1) From 4fa5b791a115687a9edaa54f9386431ac85a4c8a Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Wed, 24 Apr 2024 19:50:14 +0900 Subject: [PATCH 09/12] Refactor for SSH --- ptrlib/connection/proc.py | 2 +- ptrlib/connection/ssh.py | 67 +++++++++++++++++++++++++++++++++++++++ ptrlib/connection/tube.py | 3 +- 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index 53b20ae..5a2c906 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -190,7 +190,7 @@ def _send_impl(self, data: bytes) -> int: self._proc.stdin.flush() return n_written except IOError as err: - logger.error("Broken pipe: {str(self)}") + logger.error(f"Broken pipe: {str(self)}") raise err from None def _shutdown_recv_impl(self): diff --git a/ptrlib/connection/ssh.py b/ptrlib/connection/ssh.py index e69de29..3749087 100644 --- a/ptrlib/connection/ssh.py +++ b/ptrlib/connection/ssh.py @@ -0,0 +1,67 @@ +import shlex +import os +from ptrlib.binary.encoding import * +from ptrlib.arch.common import which +from .proc import * + +_is_windows = os.name == 'nt' + + +def SSH(host: str, + port: int, + username: str, + password: Optional[str]=None, + identity: Optional[str]=None, + ssh_path: Optional[str]=None, + expect_path: Optional[str]=None, + option: str='', + command: str=''): + """Create an SSH shell + + Create a new process to connect to SSH server + + Args: + host (str) : SSH hostname + port (int) : SSH port + username (str): SSH username + password (str): SSH password + identity (str): Path of identity file + option (str) : Parameters to pass to SSH + command (str) : Initial command to execute on remote + + Returns: + Process: ``Process`` instance. + """ + assert isinstance(port, int) + if password is None and identity is None: + raise ValueError("You must give either password or identity") + + if ssh_path is None: + ssh_path = which('ssh') + if expect_path is None: + expect_path = which('expect') + + if not os.path.isfile(ssh_path): + raise FileNotFoundError("{}: SSH not found".format(ssh_path)) + if not os.path.isfile(expect_path): + raise FileNotFoundError("{}: 'expect' not found".format(expect_path)) + + if identity is not None: + option += ' -i {}'.format(shlex.quote(identity)) + + script = 'eval spawn {} -oStrictHostKeyChecking=no -oCheckHostIP=no {}@{} -p{} {} {}; interact; lassign [wait] pid spawnid err value; exit "$value"'.format( + ssh_path, + shlex.quote(username), + shlex.quote(host), + port, + option, + command + ) + + proc = Process( + [expect_path, '-c', script], + ) + if identity is None: + proc.sendlineafter("password: ", password) + + return proc diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index 5215b93..1db2651 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -638,8 +638,7 @@ def pretty_print(data: bytes, prev: bytes=b''): if t: if 0x7f <= ord(c) < 0x100: pretty_print_hex(c) - elif ord(c) not in [0x09, 0x0a, 0x0d] and \ - ord(c) < 0x20: + elif ord(c) in [0x00]: # TODO: What is printable? pretty_print_hex(c) else: sys.stdout.write(c) From 4c80898a7265ba69699e5639de7c74808c797526 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Sat, 11 May 2024 15:06:30 +0900 Subject: [PATCH 10/12] Update ansi --- ptrlib/binary/encoding/ansi.py | 566 ++++++++++++++++++++++----------- ptrlib/connection/proc.py | 1 + ptrlib/connection/sock.py | 2 + ptrlib/connection/tube.py | 19 +- 4 files changed, 389 insertions(+), 199 deletions(-) diff --git a/ptrlib/binary/encoding/ansi.py b/ptrlib/binary/encoding/ansi.py index 11d1ebd..f74eb6b 100644 --- a/ptrlib/binary/encoding/ansi.py +++ b/ptrlib/binary/encoding/ansi.py @@ -1,193 +1,375 @@ -import functools -import re - -try: - cache = functools.cache -except AttributeError: - cache = functools.lru_cache - - -@cache -def _escape_codes(): - codes = {} - # Cursor - codes['CSI_CURSOR_MOVE'] = re.compile(rb'^\x1b\[([1-9]\d*);([1-9]\d*)[Hf]') - codes['CSI_CURSOR_ROW'] = re.compile(rb'^\x1b\[([1-9]\d*)d') - codes['CSI_CURSOR_COLUMN'] = re.compile(rb'^\x1b\[([1-9]\d*)[`G]') - codes['CSI_CURSOR_UP'] = re.compile(rb'^\x1b\[(\d*)A') - codes['CSI_CURSOR_DOWN'] = re.compile(rb'^\x1b\[(\d*)B') - codes['CSI_CURSOR_RIGHT'] = re.compile(rb'^\x1b\[(\d*)C') - codes['CSI_CURSOR_LEFT'] = re.compile(rb'^\x1b\[(\d*)D') - codes['CSI_CURSOR_UP_HEAD'] = re.compile(rb'^\x1b\[(\d*)F') - codes['CSI_CURSOR_DOWN_HEAD'] = re.compile(rb'^\x1b\[(\d*)E') - codes['CSI_CURSOR_SAVE'] = re.compile(rb'^\x1b\[s') - codes['CSI_CURSOR_RESTORE'] = re.compile(rb'^\x1b\[u') - codes['CSI_CURSOR_REQUEST'] = re.compile(rb'^\x1b\[6n') - codes['FP_CURSOR_SAVE'] = re.compile(rb'^\x1b7') - codes['FP_CURSOR_RESTORE'] = re.compile(rb'^\x1b8') - codes['FE_CURSOR_ONEUP'] = re.compile(rb'^\x1bM') - - # Character - codes['CSI_CHAR_REPEAT'] = re.compile(rb'^\x1b\[(\d+)b') - - # Erase - codes['CSI_ERASE_DISPLAY_FORWARD'] = re.compile(rb'^\x1b\[[0]J') - codes['CSI_ERASE_DISPLAY_BACKWARD'] = re.compile(rb'^\x1b\[1J') - codes['CSI_ERASE_DISPLAY_ALL'] = re.compile(rb'^\x1b\[2J') - codes['CSI_ERASE_LINE_FORWARD'] = re.compile(rb'^\x1b\[[0]K') - codes['CSI_ERASE_LINE_BACKWARD'] = re.compile(rb'^\x1b\[1K') - codes['CSI_ERASE_LINE_ALL'] = re.compile(rb'^\x1b\[2K') - - # Others - codes['CSI_COLOR'] = re.compile(rb'^\x1b\[(\d+)m') - codes['CSI_MODE'] = re.compile(rb'^\x1b\[=(\d+)[hl]') - codes['CSI_PRIVATE_MODE'] = re.compile(rb'^\x1b\[?(\d+)[hl]') - - return codes - - -def draw_ansi(buf: bytes): - """Interpret ANSI code sequences to screen - - Args: - buf (bytes): ANSI code sequences - - Returns: - list: 2D array of screen to be drawn - """ - draw = [] - E = _escape_codes() - width = height = x = y = 0 - saved_dec = saved_sco = None - while len(buf): - if buf[0] == 13: # \r - x = 0 - buf = buf[1:] - continue - - elif buf[0] == 10: # \n - x = 0 - y += 1 - buf = buf[1:] - continue - - elif buf[0] != 0x1b: - if x >= width: width = x + 1 - if y >= height: height = y + 1 - draw.append(('PUTCHAR', x, y, buf[0])) - x += 1 - buf = buf[1:] - continue - - # CSI sequences - if m := E['CSI_CURSOR_MOVE'].match(buf): - y, x = int(m.group(1)) - 1, int(m.group(2)) - 1 - elif m := E['CSI_CURSOR_ROW'].match(buf): - y = int(m.group(1)) - 1 - elif m := E['CSI_CURSOR_COLUMN'].match(buf): - x = int(m.group(1)) - 1 - elif m := E['CSI_CURSOR_UP'].match(buf): - y = max(0, y - int(m.group(1))) if m.group(1) else max(0, y-1) - elif m := E['CSI_CURSOR_DOWN'].match(buf): - y += int(m.group(1)) if m.group(1) else 1 - elif m := E['CSI_CURSOR_LEFT'].match(buf): - x = max(0, x - int(m.group(1))) if m.group(1) else max(0, x-1) - elif m := E['CSI_CURSOR_RIGHT'].match(buf): - x += int(m.group(1)) if m.group(1) else 1 - elif m := E['CSI_CURSOR_UP_HEAD'].match(buf): - x, y = 0, max(0, y - int(m.group(1))) if m.group(1) else max(0, y-1) - elif m := E['CSI_CURSOR_DOWN_HEAD'].match(buf): - x, y = 0, y + int(m.group(1)) if m.group(1) else y+1 - elif m := E['CSI_CURSOR_SAVE'].match(buf): - saved_sco = (x, y) - elif m := E['CSI_CURSOR_RESTORE'].match(buf): - if saved_sco is not None: x, y = saved_sco - elif m := E['CSI_CURSOR_REQUEST'].match(buf): - pass # Not implemented: Request cursor position - elif m := E['CSI_COLOR'].match(buf): - pass # Not implemented: Change color - elif m := E['CSI_MODE'].match(buf): - pass # Not implemented: Set mode - - # Repease character - elif m := E['CSI_CHAR_REPEAT'].match(buf): - n = int(m.group(1)) - draw.append(('CSI_CHAR_REPEAT', x, y, n)) - x += n - - # Fe escape sequences - elif m := E['FE_CURSOR_ONEUP'].match(buf): - y = max(0, y - 1) # scroll not implemented - - # Fp escape sequences - elif m := E['FP_CURSOR_SAVE'].match(buf): - saved_dec = (x, y) - elif m := E['FP_CURSOR_RESTORE'].match(buf): - if saved_dec is not None: x, y = saved_dec - - # Operation - else: - for k in ['CSI_ERASE_DISPLAY_FORWARD', - 'CSI_ERASE_DISPLAY_BACKWARD', - 'CSI_ERASE_DISPLAY_ALL', - 'CSI_ERASE_LINE_FORWARD', - 'CSI_ERASE_LINE_BACKWARD', - 'CSI_ERASE_LINE_ALL']: - if m := E[k].match(buf): - if k == 'CSI_ERASE_DISPLAY_ALL': - draw = [] - else: - draw.append((k, x, y, None)) - break - - # Otherwise draw text - if m: - buf = buf[m.end():] +import enum +from typing import Generator, List, Optional + +# Based on https://bjh21.me.uk/all-escapes/all-escapes.txt + +class AnsiOp(enum.Enum): + UNKNOWN = 0 + + # C0 Control Sequence + BEL = 0x10 + BS = enum.auto() + HT = enum.auto() + LF = enum.auto() + FF = enum.auto() + CR = enum.auto() + ESC = enum.auto() + + # Fe Escape Sequence + BPH = 0x20 # Break permitted here + NBH = enum.auto() # No break here + IND = enum.auto() # Index + NEL = enum.auto() # Next line + SSA = enum.auto() # Start of selected area + ESA = enum.auto() # End of selected area + HTS = enum.auto() # Character tabulation set + HTJ = enum.auto() # Character tabulation with justification + VTS = enum.auto() # Line tabulation set + PLD = enum.auto() # Partial line forward + PLU = enum.auto() # Partial line backward + RI = enum.auto() # Reverse line feed + SS2 = enum.auto() # Single-shift two + SS3 = enum.auto() # Single-shift three + DCS = enum.auto() # Device control string + PU1 = enum.auto() # Private use one + PU2 = enum.auto() # Private use two + STS = enum.auto() # Set transmit state + CCH = enum.auto() # Cancel character + MW = enum.auto() # Message waiting + SPA = enum.auto() # Start of guarded area + EPA = enum.auto() # End of guarded area + SOS = enum.auto() # Start of string + SCI = enum.auto() # Single character introducer + CSI = enum.auto() # Control sequence + + # CSI Sequence + ICH = 0x100 # Insert character + SBC = enum.auto() # Set border color + CUU = enum.auto() # Cursor up + SBP = enum.auto() # Set bell parameters + CUD = enum.auto() # Cursor down + SCR = enum.auto() # Set cursor parameters + CUF = enum.auto() # Cursor right + SBI = enum.auto() # Set background intensity + CUB = enum.auto() # Cursor left + SBB = enum.auto() # Set background blink bit + CNL = enum.auto() # Cursor next line + SNF = enum.auto() # Set normal foreground color + CPL = enum.auto() # Cursor preceding line + SNB = enum.auto() # Set normal background color + CHA = enum.auto() # Cursor character absolute + SRF = enum.auto() # Set reverse foreground color + CUP = enum.auto() # Cursor position + SRB = enum.auto() # Set reverse background color + CHT = enum.auto() # Cursor forward tabulation + ED = enum.auto() # Erase in page + SGF = enum.auto() # Set graphic foreground color + EL = enum.auto() # Erase in line + SGB = enum.auto() # Set graphic background color + IL = enum.auto() # Insert line + SEF = enum.auto() # Set emulator feature + DL = enum.auto() # Delete line + RAS = enum.auto() # Return attribute setting + EF = enum.auto() # Erase in field + EA = enum.auto() # Erase in area + DCH = enum.auto() # Delete character + SEE = enum.auto() # Select editing extent + CPR = enum.auto() # Active position report + SU = enum.auto() # Scroll up + SD = enum.auto() # Scroll down + NP = enum.auto() # Next page + PP = enum.auto() # Preceding page + CTC = enum.auto() # Cursor tabulation control + ECH = enum.auto() # Erase character + CVT = enum.auto() # Cursor line tabulation + CBT = enum.auto() # Cursor backward tabulation + SRS = enum.auto() # Start reversed string + PTX = enum.auto() # Parallel texts + SDS = enum.auto() # Start directed string + SIMD = enum.auto() # Select implicit movement direction + + +class AnsiInstruction(object): + def __init__(self, + c0: AnsiOp, + code: Optional[AnsiOp]=None, + args: Optional[List[int]]=None): + self._c0 = c0 + self._code = code + self._args = args + + @property + def args(self): + return self._args + + def __str__(self): + return f'' + +class AnsiParser(object): + CTRL = [0x1b, 0x07, 0x08, 0x09, 0x0a, 0x0c, 0x0d] + ESC, BEL, BS, HT, LF, FF, CR = CTRL + + def __init__(self, + generator: Generator[bytes, None, None]): + """ + Args: + generator: A generator which yields byte stream + """ + self._g = generator + self._buffer = b'' + + def _decode_csi(self) -> Optional[AnsiInstruction]: + """Decode a CSI sequence + """ + c0, code = AnsiOp.ESC, AnsiOp.CSI + + # Parse parameters + mode_set = 0 + cur = 2 + args = [] + + if cur < len(self._buffer) and self._buffer[cur] == ord('='): + mode_set = 1 + cur += 1 + + while True: + prev = cur + while cur < len(self._buffer) and 0x30 <= self._buffer[cur] <= 0x39: + cur += 1 + + if cur >= len(self._buffer): + return None + + # NOTE: Common implementation seems to skip successive delimiters + if cur != prev: + args.append(int(self._buffer[prev:cur])) + + if self._buffer[cur] == ord(';'): + cur += 1 + else: + break + + # Check mnemonic + if self._buffer[cur] == ord('@'): + code = AnsiOp.ICH + default = (1,) + elif self._buffer[cur] == ord('A'): + code = [AnsiOp.CUU, AnsiOp.SBC][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('B'): + code = [AnsiOp.CUD, AnsiOp.SBP][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('C'): + code = [AnsiOp.CUF, AnsiOp.SCR][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('D'): + code = [AnsiOp.CUB, AnsiOp.SBI][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('E'): + code = [AnsiOp.CNL, AnsiOp.SBB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('F'): + code = [AnsiOp.CPL, AnsiOp.SNF][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('G'): + code = [AnsiOp.CHA, AnsiOp.SNB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('H'): + code = [AnsiOp.CUP, AnsiOp.SRF][mode_set] + default = [(1,1), ()][mode_set] + elif self._buffer[cur] == ord('I'): + # TODO: Support screen saver off + code = [AnsiOp.CHT, AnsiOp.SRB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('J'): + # TODO: Support DECSED and screen saver on + code = [AnsiOp.ED, AnsiOp.SGF][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('K'): + # TODO: Support DECSEL + code = [AnsiOp.EL, AnsiOp.SGB][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('L'): + code = [AnsiOp.IL, AnsiOp.SEF][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('M'): + code = [AnsiOp.DL, AnsiOp.RAS][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('N'): + code, default = AnsiOp.EF, (0,) + default = (0,) + elif self._buffer[cur] == ord('O'): + code, default = AnsiOp.EA, (0,) + elif self._buffer[cur] == ord('P'): + code, default = AnsiOp.DCH, (1,) + elif self._buffer[cur] == ord('Q'): + code, default = AnsiOp.SEE, (0,) + elif self._buffer[cur] == ord('R'): + # TODO: Support DECXCPR + code, default = AnsiOp.CPR, (1, 1) + elif self._buffer[cur] == ord('S'): + code, default = AnsiOp.SU, (1,) + elif self._buffer[cur] == ord('T'): + # TODO: Support initiate hilite mouse tracking + code, default = AnsiOp.SD, (1,) + elif self._buffer[cur] == ord('U'): + code, default = AnsiOp.NP, (1,) + elif self._buffer[cur] == ord('V'): + code, default = AnsiOp.PP, (1,) + elif self._buffer[cur] == ord('W'): + # TODO: Support DECST8C + code, default = AnsiOp.CTC, (0,) + elif self._buffer[cur] == ord('X'): + code, default = AnsiOp.ECH, (1,) + elif self._buffer[cur] == ord('Y'): + code, default = AnsiOp.CVT, (1,) + elif self._buffer[cur] == ord('Z'): + code, default = AnsiOp.CBT, (1,) + elif self._buffer[cur] == ord('['): + # TODO: Support ignore next character + code, default = AnsiOp.SRS, (0,) + elif self._buffer[cur] == ord('\\'): + code, default = AnsiOp.PTX, (0,) + elif self._buffer[cur] == ord(']'): + # TODO: Support linux private sequences + code, default = AnsiOp.SDS, (0,) + elif self._buffer[cur] == ord('^'): + code, default = AnsiOp.SIMD, (0,) + + + self._buffer = self._buffer[cur+1:] + return AnsiInstruction(c0, code, args) + + def _decode_esc(self) -> Optional[AnsiInstruction]: + """Decode an ESC sequence + """ + if len(self._buffer) < 2: + return None + + c0 = AnsiOp.ESC + code = AnsiOp.UNKNOWN + if self._buffer[1] == ord('B'): + code = AnsiOp.BPH + elif self._buffer[1] == ord('C'): + code = AnsiOp.NBH + elif self._buffer[1] == ord('D'): + code = AnsiOp.IND + elif self._buffer[1] == ord('E'): + code = AnsiOp.NEL + elif self._buffer[1] == ord('F'): + code = AnsiOp.SSA + elif self._buffer[1] == ord('G'): + code = AnsiOp.ESA + elif self._buffer[1] == ord('H'): + code = AnsiOp.HTS + elif self._buffer[1] == ord('I'): + code = AnsiOp.HTJ + elif self._buffer[1] == ord('J'): + code = AnsiOp.VTS + elif self._buffer[1] == ord('K'): + code = AnsiOp.PLD + elif self._buffer[1] == ord('L'): + code = AnsiOp.PLU + elif self._buffer[1] == ord('M'): + code = AnsiOp.RI + elif self._buffer[1] == ord('N'): + code = AnsiOp.SS2 + elif self._buffer[1] == ord('O'): + code = AnsiOp.SS3 + elif self._buffer[1] == ord('P'): + code = AnsiOp.DCS + elif self._buffer[1] == ord('Q'): + code = AnsiOp.PU1 + elif self._buffer[1] == ord('R'): + code = AnsiOp.PU2 + elif self._buffer[1] == ord('S'): + code = AnsiOp.STS + elif self._buffer[1] == ord('T'): + code = AnsiOp.CCH + elif self._buffer[1] == ord('U'): + code = AnsiOp.MW + elif self._buffer[1] == ord('V'): + code = AnsiOp.SPA + elif self._buffer[1] == ord('W'): + code = AnsiOp.EPA + elif self._buffer[1] == ord('X'): + code = AnsiOp.SOS + elif self._buffer[1] == ord('Z'): + code = AnsiOp.SCI + elif self._buffer[1] == ord('['): + return self._decode_csi() + + return AnsiInstruction(c0, code) + + """ + elif self._buffer[1] == 0x5c: + code = AnsiOp.ST + elif self._buffer[1] == 0x5d: + code = AnsiOp.OSC + elif self._buffer[1] == 0x5e: + code = AnsiOp.PM + elif self._buffer[1] == 0x5f: + code = AnsiOp.APC + """ + + def parse_block(self) -> Optional[AnsiInstruction]: + """Parse a block of ANSI escape sequence + + Returns: + AnsiInstruction: Instruction, or None if need more data + + Raises: + StopIteration: No more data to receive + """ + try: + self._buffer += next(self._g) + except StopIteration: + pass + while len(self._buffer) == 0: + self._buffer += next(self._g) + + # TODO: Support C1 control code + if self._buffer[0] not in AnsiParser.CTRL: + # Return until a control code appears + for i, c in enumerate(self._buffer): + if c in AnsiParser.CTRL: + data, self._buffer = self._buffer[:i], self._buffer[i:] + return data + + data, self._buffer = self._buffer, b'' + return data + + # Check C0 control sequence + if self._buffer[0] == AnsiParser.BEL: # BEL + instr = AnsiInstruction(AnsiOp.BEL) + elif self._buffer[0] == AnsiParser.BS: # BS + instr = AnsiInstruction(AnsiOp.BS) + elif self._buffer[0] == AnsiParser.HT: # HT + instr = AnsiInstruction(AnsiOp.HT) + elif self._buffer[0] == AnsiParser.LF: # LF + instr = AnsiInstruction(AnsiOp.LF) + elif self._buffer[0] == AnsiParser.FF: # FF + instr = AnsiInstruction(AnsiOp.FF) + elif self._buffer[0] == AnsiParser.CR: # CR + instr = AnsiInstruction(AnsiOp.CR) else: - # TODO: skip ESC only? - raise NotImplementedError(f"Could not interpret code: {buf[:10]}") - - # Emualte drawing - screen = [[' ' for x in range(width)] for y in range(height)] - last_char = ' ' - for op, x, y, attr in draw: - if op == 'PUTCHAR': - last_char = chr(attr) - screen[y][x] = last_char - - elif op == 'CSI_CHAR_REPEAT': - for j in range(attr): - screen[y][x+j] = last_char - - elif op == 'CSI_ERASE_DISPLAY_FORWARD': - for j in range(x, width): - screen[y][j] = ' ' - for i in range(y+1, height): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_DISPLAY_BACKWARD': - for j in range(x): - screen[y][j] = ' ' - for i in range(y): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_DISPLAY_ALL': - for i in range(height): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_LINE_FORWARD': - for j in range(x, width): - screen[y][j] = ' ' - - elif op == 'CSI_ERASE_LINE_BACKWARD': - for j in range(x): - screen[y][j] = ' ' - - elif op == 'CSI_ERASE_LINE_ALL': - for j in range(width): - screen[y][j] = ' ' - - return screen + return self._decode_esc() + + self._buffer = self._buffer[1:] + return instr + +if __name__ == '__main__': + def test(): + yield b"ABC\n\x1b[12;23H\x08\x1b[30" + yield b"m\x1b[47mHello" + + ansi = AnsiParser(test()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + print(ansi.parse_block()) + diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index 5a2c906..2a69c2a 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -129,6 +129,7 @@ def __init__(self, fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) logger.info(f"Successfully created new process {str(self)}") + self._init_done = True # # Properties diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index e97282c..5386587 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -82,6 +82,8 @@ def __init__(self, logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None + self._init_done = True + # # Implementation of Tube methods # diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index 1db2651..f64ad7e 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -4,7 +4,7 @@ import threading from logging import getLogger from typing import List, Literal, Optional, Tuple, Union -from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump +from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump, draw_ansi from ptrlib.console.color import Color logger = getLogger(__name__) @@ -262,6 +262,9 @@ def recvuntil(self, else: delim = [str2bytes(delim)] + if any(map(lambda d: len(d) == 0, delim)): + return b'' # Empty delimiter + # Iterate until we find one of the delimiters found_delim = None prev_len = 0 @@ -403,17 +406,17 @@ def recvregex(self, def recvscreen(self, delim: Optional[Union[str, bytes]]=b'\x1b[H', returns: Optional[type]=str, - timeout: Optional[Union[int, float]]=None, - timeout2: Optional[Union[int, float]]=1): + prev: Optional[Union[str, bytes, list]]=None, + timeout: Optional[Union[int, float]]=None): """Receive a screen - Receive a screen drawn by ncurses + Receive a screen drawn by ncurses (ANSI escape sequence) Args: delim : Refresh sequence returns : Return value as string or list - timeout : Timeout to receive the first delimiter - timeout2: Timeout to receive the second delimiter + prev : Previous screen (Use when screen is partially updated) + timeout : Timeout until receiving the delimiter Returns: str: Rectangle string drawing the screen @@ -426,6 +429,8 @@ def recvscreen(self, """ assert returns in [list, str, bytes], \ "`returns` must be either list, str, or bytes" + assert prev is None or isinstance(prev, (str, bytes, list)), \ + "`prev` must be either list, str, or bytes" try: self.recvuntil(delim, timeout=timeout) @@ -780,7 +785,7 @@ def __str__(self) -> str: return "" def __del__(self): - if not self._is_closed: + if hasattr(self, '_init_done') and not self._is_closed: self.close() # From 9fcd0ba4ee45540972980f0cc2023e439fa01cf2 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Sat, 11 May 2024 23:21:09 +0900 Subject: [PATCH 11/12] Improved ANSI interpreter --- ptrlib/binary/encoding/__init__.py | 2 +- ptrlib/binary/encoding/ansi.py | 514 ++++++++++++++---- ptrlib/binary/encoding/{locale.py => char.py} | 0 ptrlib/connection/tube.py | 60 +- 4 files changed, 442 insertions(+), 134 deletions(-) rename ptrlib/binary/encoding/{locale.py => char.py} (100%) diff --git a/ptrlib/binary/encoding/__init__.py b/ptrlib/binary/encoding/__init__.py index 2fb0c06..d77b38f 100644 --- a/ptrlib/binary/encoding/__init__.py +++ b/ptrlib/binary/encoding/__init__.py @@ -1,6 +1,6 @@ from .ansi import * from .bitconv import * from .byteconv import * +from .char import * from .dump import * -from .locale import * from .table import * diff --git a/ptrlib/binary/encoding/ansi.py b/ptrlib/binary/encoding/ansi.py index f74eb6b..41179fd 100644 --- a/ptrlib/binary/encoding/ansi.py +++ b/ptrlib/binary/encoding/ansi.py @@ -1,8 +1,11 @@ import enum -from typing import Generator, List, Optional +from logging import getLogger +from typing import Callable, Generator, List, Optional, Tuple, Union + +logger = getLogger(__name__) -# Based on https://bjh21.me.uk/all-escapes/all-escapes.txt +# Based on https://bjh21.me.uk/all-escapes/all-escapes.txt class AnsiOp(enum.Enum): UNKNOWN = 0 @@ -42,6 +45,9 @@ class AnsiOp(enum.Enum): SCI = enum.auto() # Single character introducer CSI = enum.auto() # Control sequence + # Fp Private Control Functions + DECKPAM = 0x80 + # CSI Sequence ICH = 0x100 # Insert character SBC = enum.auto() # Set border color @@ -87,7 +93,43 @@ class AnsiOp(enum.Enum): PTX = enum.auto() # Parallel texts SDS = enum.auto() # Start directed string SIMD = enum.auto() # Select implicit movement direction - + HPA = enum.auto() # Character position absolute + HPR = enum.auto() # Character position forward + REP = enum.auto() # Repeat + DA = enum.auto() # Device attributes + HSC = enum.auto() # Hide or show cursor + VPA = enum.auto() # Line position absolute + VPR = enum.auto() # Line position forward + HVP = enum.auto() # Character and line position + TBC = enum.auto() # Tabulation clear + PRC = enum.auto() # Print ROM character + SM = enum.auto() # Set mode + MC = enum.auto() # Media copy + HPB = enum.auto() # Character position backward + VPB = enum.auto() # Line position backward + RM = enum.auto() # Reset mode + CHC = enum.auto() # Clear and home cursor + SGR = enum.auto() # Select graphic rendition + SSM = enum.auto() # Set specific margin + DSR = enum.auto() # Device status report + DAQ = enum.auto() # Device area qualification + DECSSL = enum.auto() # Select set-up language + DECLL = enum.auto() # Load LEDs + DECSTBM = enum.auto() # Set top and bottom margins + RSM = enum.auto() # Reset margins + SCP = enum.auto() # Save cursor position + DECSLPP = enum.auto() # Set lines per physical page + RCP = enum.auto() # Reset cursor position + DECSVTS = enum.auto() # Set vertical tab stops + DECSHORP = enum.auto() # Set horizontal pitch + DGRTC = enum.auto() # Request terminal configuration + DECTST = enum.auto() # Invoke confidence test + SSW = enum.auto() # Screen switch + CAT = enum.auto() # Clear all tabs + + # SCS: Select character set + SCS_B = 0x200 # Default charset + SCS_0 = enum.auto() # DEC special charset class AnsiInstruction(object): def __init__(self, @@ -99,8 +141,40 @@ def __init__(self, self._args = args @property - def args(self): - return self._args + def is_skip(self): + """Check if instruction can be skipped + + Returns: + bool: True if this instruction is not important for drawing screen + """ + return self._code in [ + AnsiOp.DECKPAM, + AnsiOp.DECSLPP, + AnsiOp.DECSTBM, + AnsiOp.SGR, + ] + + def __getitem__(self, i: int): + assert isinstance(i, int), "Slice must be integer" + if i < 0 or i >= len(self._args): + return None + else: + return self._args[i] + + def __eq__(self, other): + if isinstance(other, AnsiInstruction): + return self._c0 == other._c0 and \ + self._code == other._code and \ + self._args == other._args + + elif isinstance(other, AnsiOp): + return self._c0 == other or self._code == other + + else: + raise TypeError(f"Cannot compare AnsiInstruction and {type(other)}") + + def __neq__(self, other): + return not self.__eq__(other) def __str__(self): return f'' @@ -110,13 +184,32 @@ class AnsiParser(object): ESC, BEL, BS, HT, LF, FF, CR = CTRL def __init__(self, - generator: Generator[bytes, None, None]): + generator: Generator[bytes, None, None], + size: Tuple[int, int]=(0, 0), + pos: Tuple[int, int]=(0, 0)): """ Args: generator: A generator which yields byte stream + size: Initial screen size (width, height) + pos: Initial cursor position (x, y) """ self._g = generator self._buffer = b'' + self._width, self._height = size + self._x, self._y = pos + self._last_size = 0 + + @property + def buffer(self) -> bytes: + """Return contents of current buffering + """ + return self._buffer + + def _experimantal_warning(self, message: str): + logger.error(message) + logger.error("This feature is experimental and does not support some ANSI codes.\n" \ + "If you encounter this error, please create an issue here:\n" \ + "https://github.com/ptr-yudai/ptrlib/issues") def _decode_csi(self) -> Optional[AnsiInstruction]: """Decode a CSI sequence @@ -124,12 +217,19 @@ def _decode_csi(self) -> Optional[AnsiInstruction]: c0, code = AnsiOp.ESC, AnsiOp.CSI # Parse parameters - mode_set = 0 + mode_set, mode_q, mode_private = 0, 0, 0 cur = 2 args = [] - if cur < len(self._buffer) and self._buffer[cur] == ord('='): - mode_set = 1 + while cur < len(self._buffer) and self._buffer[cur] in [ord('='), ord('?'), ord('>')]: + if self._buffer[cur] == ord('='): + mode_set = 1 + elif self._buffer[cur] == ord('?'): + mode_q = 1 + elif self._buffer[cur] == ord('>'): # TODO: Is this correct? + mode_private = 1 + else: + raise NotImplementedError("BUG: Unreachable path") cur += 1 while True: @@ -138,6 +238,7 @@ def _decode_csi(self) -> Optional[AnsiInstruction]: cur += 1 if cur >= len(self._buffer): + self._last_size = len(self._buffer) return None # NOTE: Common implementation seems to skip successive delimiters @@ -235,7 +336,78 @@ def _decode_csi(self) -> Optional[AnsiInstruction]: code, default = AnsiOp.SDS, (0,) elif self._buffer[cur] == ord('^'): code, default = AnsiOp.SIMD, (0,) - + elif self._buffer[cur] == ord('`'): + code, default = AnsiOp.HPA, (1,) + elif self._buffer[cur] == ord('a'): + code, default = AnsiOp.HPR, (1,) + elif self._buffer[cur] == ord('b'): + code, default = AnsiOp.REP, (1,) + elif self._buffer[cur] == ord('c'): + # NOTE: This operation has a lot of meanings + code = [AnsiOp.DA, AnsiOp.HSC][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('d'): + code, default = AnsiOp.VPA, (1,) + elif self._buffer[cur] == ord('e'): + code, default = AnsiOp.VPR, (1,) + elif self._buffer[cur] == ord('f'): + code, default = AnsiOp.HVP, (1, 1) + elif self._buffer[cur] == ord('g'): + # TODO: Support reset tabs + code = [AnsiOp.TBC, AnsiOp.PRC][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('h'): + code, default = AnsiOp.SM, () + elif self._buffer[cur] == ord('i'): + code, default = AnsiOp.MC, () + elif self._buffer[cur] == ord('j'): + code, default = AnsiOp.HPB, (1,) + elif self._buffer[cur] == ord('k'): + code, default = AnsiOp.VPB, (1,) + elif self._buffer[cur] == ord('l'): + # TODO: Support insert line up + code = [AnsiOp.RM, AnsiOp.CHC][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('m'): + # TODO: Support delete line down + code = [AnsiOp.SGR, AnsiOp.SSM][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('n'): + code, default = AnsiOp.DSR, (0,) + elif self._buffer[cur] == ord('o'): + code, default = AnsiOp.DAQ, (0,) + elif self._buffer[cur] == ord('p'): + code, default = AnsiOp.DECSSL, () + elif self._buffer[cur] == ord('q'): + code, default = AnsiOp.DECLL, () + elif self._buffer[cur] == ord('r'): + # TODO: Support CSR and SUNSCRL + code = [AnsiOp.DECSTBM, AnsiOp.RSM][mode_set] + default = [(), ()][mode_set] + elif self._buffer[cur] == ord('s'): + code, default = AnsiOp.SCP, () + elif self._buffer[cur] == ord('t'): + code, default = AnsiOp.DECSLPP, () + elif self._buffer[cur] == ord('u'): + code, default = AnsiOp.RCP, () + elif self._buffer[cur] == ord('v'): + code, default = AnsiOp.DECSVTS, () + elif self._buffer[cur] == ord('w'): + code, default = AnsiOp.DECSHORP, () + elif self._buffer[cur] == ord('x'): + code, default = AnsiOp.DGRTC, () + elif self._buffer[cur] == ord('y'): + code, default = AnsiOp.DECTST, () + elif self._buffer[cur] == ord('z'): + # TODO: Support + code = [AnsiOp.SSW, AnsiOp.CAT][mode_set] + default = [(), ()][mode_set] + else: + self._experimantal_warning(f"CSI not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError("Unknown CSI") + + if len(args) < len(default): + args = tuple(args + list(default[len(args):])) self._buffer = self._buffer[cur+1:] return AnsiInstruction(c0, code, args) @@ -243,76 +415,92 @@ def _decode_csi(self) -> Optional[AnsiInstruction]: def _decode_esc(self) -> Optional[AnsiInstruction]: """Decode an ESC sequence """ - if len(self._buffer) < 2: - return None - c0 = AnsiOp.ESC code = AnsiOp.UNKNOWN - if self._buffer[1] == ord('B'): - code = AnsiOp.BPH - elif self._buffer[1] == ord('C'): - code = AnsiOp.NBH - elif self._buffer[1] == ord('D'): - code = AnsiOp.IND - elif self._buffer[1] == ord('E'): - code = AnsiOp.NEL - elif self._buffer[1] == ord('F'): - code = AnsiOp.SSA - elif self._buffer[1] == ord('G'): - code = AnsiOp.ESA - elif self._buffer[1] == ord('H'): - code = AnsiOp.HTS - elif self._buffer[1] == ord('I'): - code = AnsiOp.HTJ - elif self._buffer[1] == ord('J'): - code = AnsiOp.VTS - elif self._buffer[1] == ord('K'): - code = AnsiOp.PLD - elif self._buffer[1] == ord('L'): - code = AnsiOp.PLU - elif self._buffer[1] == ord('M'): - code = AnsiOp.RI - elif self._buffer[1] == ord('N'): - code = AnsiOp.SS2 - elif self._buffer[1] == ord('O'): - code = AnsiOp.SS3 - elif self._buffer[1] == ord('P'): - code = AnsiOp.DCS - elif self._buffer[1] == ord('Q'): - code = AnsiOp.PU1 - elif self._buffer[1] == ord('R'): - code = AnsiOp.PU2 - elif self._buffer[1] == ord('S'): - code = AnsiOp.STS - elif self._buffer[1] == ord('T'): - code = AnsiOp.CCH - elif self._buffer[1] == ord('U'): - code = AnsiOp.MW - elif self._buffer[1] == ord('V'): - code = AnsiOp.SPA - elif self._buffer[1] == ord('W'): - code = AnsiOp.EPA - elif self._buffer[1] == ord('X'): - code = AnsiOp.SOS - elif self._buffer[1] == ord('Z'): - code = AnsiOp.SCI - elif self._buffer[1] == ord('['): - return self._decode_csi() - return AnsiInstruction(c0, code) + cur = 1 + if len(self._buffer) <= cur: + self._last_size = len(self._buffer) + return None - """ - elif self._buffer[1] == 0x5c: - code = AnsiOp.ST - elif self._buffer[1] == 0x5d: - code = AnsiOp.OSC - elif self._buffer[1] == 0x5e: - code = AnsiOp.PM - elif self._buffer[1] == 0x5f: - code = AnsiOp.APC - """ + if self._buffer[cur] == ord('['): + cur += 1 + if self._buffer[cur] == ord('B'): + code = AnsiOp.BPH + elif self._buffer[cur] == ord('C'): + code = AnsiOp.NBH + elif self._buffer[cur] == ord('D'): + code = AnsiOp.IND + elif self._buffer[cur] == ord('E'): + code = AnsiOp.NEL + elif self._buffer[cur] == ord('F'): + code = AnsiOp.SSA + elif self._buffer[cur] == ord('G'): + code = AnsiOp.ESA + elif self._buffer[cur] == ord('H'): + code = AnsiOp.HTS + elif self._buffer[cur] == ord('I'): + code = AnsiOp.HTJ + elif self._buffer[cur] == ord('J'): + code = AnsiOp.VTS + elif self._buffer[cur] == ord('K'): + code = AnsiOp.PLD + elif self._buffer[cur] == ord('L'): + code = AnsiOp.PLU + elif self._buffer[cur] == ord('M'): + code = AnsiOp.RI + elif self._buffer[cur] == ord('N'): + code = AnsiOp.SS2 + elif self._buffer[cur] == ord('O'): + code = AnsiOp.SS3 + elif self._buffer[cur] == ord('P'): + code = AnsiOp.DCS + elif self._buffer[cur] == ord('Q'): + code = AnsiOp.PU1 + elif self._buffer[cur] == ord('R'): + code = AnsiOp.PU2 + elif self._buffer[cur] == ord('S'): + code = AnsiOp.STS + elif self._buffer[cur] == ord('T'): + code = AnsiOp.CCH + elif self._buffer[cur] == ord('U'): + code = AnsiOp.MW + elif self._buffer[cur] == ord('V'): + code = AnsiOp.SPA + elif self._buffer[cur] == ord('W'): + code = AnsiOp.EPA + elif self._buffer[cur] == ord('X'): + code = AnsiOp.SOS + elif self._buffer[cur] == ord('Z'): + code = AnsiOp.SCI + else: + return self._decode_csi() + + elif self._buffer[cur] == ord('('): + cur += 1 + if len(self._buffer) <= cur: + self._last_size = len(self._buffer) + return None + + if self._buffer[cur] == ord('B'): + code = AnsiOp.SCS_B + elif self._buffer[cur] == ord('0'): + code = AnsiOp.SCS_0 + else: + self._experimantal_warning(f"ESC not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError(f"Unknown ESC") + + elif self._buffer[cur] == ord('='): + code = AnsiOp.DECKPAM + + else: + self._experimantal_warning(f"ESC not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError(f"Unknown ESC") + + self._buffer = self._buffer[cur+1:] + return AnsiInstruction(c0, code) - def parse_block(self) -> Optional[AnsiInstruction]: + def parse_block(self) -> Optional[Union[bytes, AnsiInstruction]]: """Parse a block of ANSI escape sequence Returns: @@ -321,12 +509,15 @@ def parse_block(self) -> Optional[AnsiInstruction]: Raises: StopIteration: No more data to receive """ - try: - self._buffer += next(self._g) - except StopIteration: - pass - while len(self._buffer) == 0: - self._buffer += next(self._g) + if len(self._buffer) <= self._last_size: + try: + self._buffer += next(self._g) + except StopIteration as e: + if len(self._buffer) == 0: + # All processed, end of input + raise e from None + + self._last_size = 0 # TODO: Support C1 control code if self._buffer[0] not in AnsiParser.CTRL: @@ -358,18 +549,149 @@ def parse_block(self) -> Optional[AnsiInstruction]: self._buffer = self._buffer[1:] return instr -if __name__ == '__main__': - def test(): - yield b"ABC\n\x1b[12;23H\x08\x1b[30" - yield b"m\x1b[47mHello" - - ansi = AnsiParser(test()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - print(ansi.parse_block()) - + def _update_screen_size(self, screen): + if len(screen) == 0: + return + self._width = max(map(lambda pos: pos[0], screen.keys())) + 1 + self._height = max(map(lambda pos: pos[1], screen.keys())) + 1 + + def _special_char(self, charset: AnsiOp, c: int): + if charset == AnsiOp.SCS_B: + return c + + elif charset == AnsiOp.SCS_0: + if 0x5f <= c <= 0x7e: + return [0x20, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x6f, + 0x2b, 0x3f, 0x3f, 0x2b, 0x2b, 0x2b, 0x2b, 0x2b, + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x2b, 0x2b, 0x2b, + 0x2b, 0x7c, 0x3c, 0x3e, 0x6e, 0x3d, 0x66, 0x2e][c - 0x5f] + else: + return c + + else: + self._experimantal_warning(f"Character set not implemented: {charset}") + raise NotImplementedError("Unknown character set") + + def draw_screen(self, + returns: type=list, + stop: Optional[Callable[[AnsiInstruction], bool]]=None) -> list: + """Receive a screen + + Args: + returns: Either str or list + stop: Function to determine when to stop emulating instructions + """ + if stop is None: + # Default stop checker designed for ncurses games + stop = lambda instr: instr == AnsiOp.HTS + + # NOTE: These variables are global so that we can support + # successive draws in the future + self._width = self._height = 0 + + screen = {} + charset = AnsiOp.SCS_B + DEL = 0x20 # Empty + last_char = DEL + stop_recv = False + while not stop_recv: + instr = None + try: + while instr is None: + instr = self.parse_block() + except StopIteration: + break + + stop_recv = stop(instr) + + if isinstance(instr, bytes): + # TODO: Reverse order? + for c in instr: + screen[(self._x, self._y)] = self._special_char(charset, c) + self._x += 1 + last_char = c + + else: + if instr.is_skip: + continue + + elif instr == AnsiOp.SCS_B: # English mode + charset = AnsiOp.SCS_B + + elif instr == AnsiOp.BS: # Back space + self._x = max(0, self._x - 1) + stop_recv = True + + elif instr == AnsiOp.CHA: # Cursor character absolute + self._x = instr[0] - 1 + + elif instr == AnsiOp.SCS_0: # DEC special graphic + charset = AnsiOp.SCS_0 + + elif instr == AnsiOp.CR: # Carriage return + self._x, self._y = 0, self._y + 1 + + elif instr == AnsiOp.CUP: # Cursor position + self._x, self._y = instr[1] - 1, instr[0] - 1 + + elif instr == AnsiOp.ECH: # Erase character + for x in range(self._x, self._x + instr[0]): + screen[(x, self._y)] = DEL + + elif instr == AnsiOp.ED: # Erase in page + self._update_screen_size(screen) + if instr[0] == 0: + for y in range(self._y, self._height): + screen[(self._x, y)] = DEL + elif instr[0] == 1: + for y in range(self._y + 1): + screen[(self._x, y)] = DEL + elif instr[0] == 2: + for y in range(self._height): + screen[(self._x, y)] = DEL + + elif instr == AnsiOp.EL: # Erase in line + self._update_screen_size(screen) + if instr[0] == 0: + for x in range(self._x, self._width): + screen[(x, self._y)] = DEL + elif instr[0] == 1: + for x in range(self._x + 1): + screen[(x, self._y)] = DEL + elif instr[0] == 2: + for x in range(self._width): + screen[(x, self._y)] = DEL + + elif instr == AnsiOp.HTS: + self._x, self._y = 0, 0 + + elif instr == AnsiOp.LF: + self._x, self._y = 0, self._y + 1 + + elif instr == AnsiOp.REP: # Repeat + for x in range(self._x, self._x + instr[0]): + screen[(x, self._y)] = self._special_char(charset, last_char) + self._x += instr[0] + + elif instr == AnsiOp.RM: # Reset mode + pass # TODO: ? + + elif instr == AnsiOp.SM: # Set mode + pass # TODO: ? + + elif instr == AnsiOp.VPA: # Line position absolute + self._y = instr[0] - 1 + + else: + raise ValueError(f"Emulation not supported for instruction {instr}") + + self._update_screen_size(screen) + field = [[' ' for x in range(self._width)] + for y in range(self._height)] + for (x, y) in screen: + field[y][x] = chr(screen[(x, y)]) + + if returns == list: + return field + else: + return '\n'.join(map(lambda line: ''.join(line), field)) diff --git a/ptrlib/binary/encoding/locale.py b/ptrlib/binary/encoding/char.py similarity index 100% rename from ptrlib/binary/encoding/locale.py rename to ptrlib/binary/encoding/char.py diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index f64ad7e..36804f9 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -3,8 +3,8 @@ import sys import threading from logging import getLogger -from typing import List, Literal, Optional, Tuple, Union -from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump, draw_ansi +from typing import Callable, List, Literal, Optional, Tuple, Union +from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump, AnsiParser, AnsiInstruction from ptrlib.console.color import Color logger = getLogger(__name__) @@ -404,52 +404,38 @@ def recvregex(self, return match.group() def recvscreen(self, - delim: Optional[Union[str, bytes]]=b'\x1b[H', - returns: Optional[type]=str, - prev: Optional[Union[str, bytes, list]]=None, - timeout: Optional[Union[int, float]]=None): + returns: type=str, + stop: Optional[Callable[[AnsiInstruction], bool]]=None, + timeout: Union[int, float]=1.0): """Receive a screen Receive a screen drawn by ncurses (ANSI escape sequence) Args: - delim : Refresh sequence - returns : Return value as string or list - prev : Previous screen (Use when screen is partially updated) - timeout : Timeout until receiving the delimiter + returns: Either str or list + stop: Function to determine when to stop emulating instructions + timeout: Timeout until stopping recv Returns: str: Rectangle string drawing the screen - - Raises: - ConnectionAbortedError: Connection is aborted by process - ConnectionResetError: Connection is closed by peer - TimeoutError: Timeout exceeded - OSError: System error """ assert returns in [list, str, bytes], \ - "`returns` must be either list, str, or bytes" - assert prev is None or isinstance(prev, (str, bytes, list)), \ - "`prev` must be either list, str, or bytes" - - try: - self.recvuntil(delim, timeout=timeout) - except TimeoutError as err: - # NOTE: We do not set received value here - raise TimeoutError("Timeout (recvscreen)", b'') + "`returns` must be either list or str" - try: - buf = self.recvuntil(delim, drop=True, lookahead=True, timeout=timeout2) - except TimeoutError as err: - buf = err.args[1] + def _ansi_stream(self): + """Generator for recvscreen + """ + while True: + try: + yield self.recv(timeout=timeout) + except TimeoutError as e: + self.unget(e.args[1]) + break - screen = draw_ansi(buf) - if returns == str: - return '\n'.join(map(lambda row: ''.join(row), screen)) - elif returns == bytes: - return b'\n'.join(map(lambda row: bytes(row), screen)) - else: - return screen + ansi = AnsiParser(_ansi_stream(self)) + scr = ansi.draw_screen(returns, stop) + self.unget(ansi.buffer) + return scr def send(self, data: Union[str, bytes]) -> int: """Send raw data @@ -669,7 +655,7 @@ def thread_recv(flag: threading.Event): flag.set() except TimeoutError: - pass + pass # NOTE: We can ignore args since recv will never buffer except EOFError: logger.error("Receiver EOF") break From f41c08cc07acbc22ea13824929f9a2bcf8225080 Mon Sep 17 00:00:00 2001 From: Yudai Fujiwara Date: Sat, 11 May 2024 23:23:49 +0900 Subject: [PATCH 12/12] Fix module name --- tests/binary/encoding/test_locale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/binary/encoding/test_locale.py b/tests/binary/encoding/test_locale.py index 476c4dc..c82d9fb 100644 --- a/tests/binary/encoding/test_locale.py +++ b/tests/binary/encoding/test_locale.py @@ -1,5 +1,5 @@ import unittest -from ptrlib.binary.encoding.locale import * +from ptrlib.binary.encoding.char import * from logging import getLogger, FATAL