diff --git a/ptrlib/binary/encoding/__init__.py b/ptrlib/binary/encoding/__init__.py index 2fb0c06..d77b38f 100644 --- a/ptrlib/binary/encoding/__init__.py +++ b/ptrlib/binary/encoding/__init__.py @@ -1,6 +1,6 @@ from .ansi import * from .bitconv import * from .byteconv import * +from .char import * from .dump import * -from .locale import * from .table import * diff --git a/ptrlib/binary/encoding/ansi.py b/ptrlib/binary/encoding/ansi.py index 11d1ebd..41179fd 100644 --- a/ptrlib/binary/encoding/ansi.py +++ b/ptrlib/binary/encoding/ansi.py @@ -1,193 +1,697 @@ -import functools -import re - -try: - cache = functools.cache -except AttributeError: - cache = functools.lru_cache - - -@cache -def _escape_codes(): - codes = {} - # Cursor - codes['CSI_CURSOR_MOVE'] = re.compile(rb'^\x1b\[([1-9]\d*);([1-9]\d*)[Hf]') - codes['CSI_CURSOR_ROW'] = re.compile(rb'^\x1b\[([1-9]\d*)d') - codes['CSI_CURSOR_COLUMN'] = re.compile(rb'^\x1b\[([1-9]\d*)[`G]') - codes['CSI_CURSOR_UP'] = re.compile(rb'^\x1b\[(\d*)A') - codes['CSI_CURSOR_DOWN'] = re.compile(rb'^\x1b\[(\d*)B') - codes['CSI_CURSOR_RIGHT'] = re.compile(rb'^\x1b\[(\d*)C') - codes['CSI_CURSOR_LEFT'] = re.compile(rb'^\x1b\[(\d*)D') - codes['CSI_CURSOR_UP_HEAD'] = re.compile(rb'^\x1b\[(\d*)F') - codes['CSI_CURSOR_DOWN_HEAD'] = re.compile(rb'^\x1b\[(\d*)E') - codes['CSI_CURSOR_SAVE'] = re.compile(rb'^\x1b\[s') - codes['CSI_CURSOR_RESTORE'] = re.compile(rb'^\x1b\[u') - codes['CSI_CURSOR_REQUEST'] = re.compile(rb'^\x1b\[6n') - codes['FP_CURSOR_SAVE'] = re.compile(rb'^\x1b7') - codes['FP_CURSOR_RESTORE'] = re.compile(rb'^\x1b8') - codes['FE_CURSOR_ONEUP'] = re.compile(rb'^\x1bM') - - # Character - codes['CSI_CHAR_REPEAT'] = re.compile(rb'^\x1b\[(\d+)b') - - # Erase - codes['CSI_ERASE_DISPLAY_FORWARD'] = re.compile(rb'^\x1b\[[0]J') - codes['CSI_ERASE_DISPLAY_BACKWARD'] = re.compile(rb'^\x1b\[1J') - codes['CSI_ERASE_DISPLAY_ALL'] = re.compile(rb'^\x1b\[2J') - codes['CSI_ERASE_LINE_FORWARD'] = re.compile(rb'^\x1b\[[0]K') - codes['CSI_ERASE_LINE_BACKWARD'] = re.compile(rb'^\x1b\[1K') - codes['CSI_ERASE_LINE_ALL'] = re.compile(rb'^\x1b\[2K') - - # Others - codes['CSI_COLOR'] = re.compile(rb'^\x1b\[(\d+)m') - codes['CSI_MODE'] = re.compile(rb'^\x1b\[=(\d+)[hl]') - codes['CSI_PRIVATE_MODE'] = re.compile(rb'^\x1b\[?(\d+)[hl]') - - return codes - - -def draw_ansi(buf: bytes): - """Interpret ANSI code sequences to screen - - Args: - buf (bytes): ANSI code sequences - - Returns: - list: 2D array of screen to be drawn - """ - draw = [] - E = _escape_codes() - width = height = x = y = 0 - saved_dec = saved_sco = None - while len(buf): - if buf[0] == 13: # \r - x = 0 - buf = buf[1:] - continue - - elif buf[0] == 10: # \n - x = 0 - y += 1 - buf = buf[1:] - continue - - elif buf[0] != 0x1b: - if x >= width: width = x + 1 - if y >= height: height = y + 1 - draw.append(('PUTCHAR', x, y, buf[0])) - x += 1 - buf = buf[1:] - continue - - # CSI sequences - if m := E['CSI_CURSOR_MOVE'].match(buf): - y, x = int(m.group(1)) - 1, int(m.group(2)) - 1 - elif m := E['CSI_CURSOR_ROW'].match(buf): - y = int(m.group(1)) - 1 - elif m := E['CSI_CURSOR_COLUMN'].match(buf): - x = int(m.group(1)) - 1 - elif m := E['CSI_CURSOR_UP'].match(buf): - y = max(0, y - int(m.group(1))) if m.group(1) else max(0, y-1) - elif m := E['CSI_CURSOR_DOWN'].match(buf): - y += int(m.group(1)) if m.group(1) else 1 - elif m := E['CSI_CURSOR_LEFT'].match(buf): - x = max(0, x - int(m.group(1))) if m.group(1) else max(0, x-1) - elif m := E['CSI_CURSOR_RIGHT'].match(buf): - x += int(m.group(1)) if m.group(1) else 1 - elif m := E['CSI_CURSOR_UP_HEAD'].match(buf): - x, y = 0, max(0, y - int(m.group(1))) if m.group(1) else max(0, y-1) - elif m := E['CSI_CURSOR_DOWN_HEAD'].match(buf): - x, y = 0, y + int(m.group(1)) if m.group(1) else y+1 - elif m := E['CSI_CURSOR_SAVE'].match(buf): - saved_sco = (x, y) - elif m := E['CSI_CURSOR_RESTORE'].match(buf): - if saved_sco is not None: x, y = saved_sco - elif m := E['CSI_CURSOR_REQUEST'].match(buf): - pass # Not implemented: Request cursor position - elif m := E['CSI_COLOR'].match(buf): - pass # Not implemented: Change color - elif m := E['CSI_MODE'].match(buf): - pass # Not implemented: Set mode - - # Repease character - elif m := E['CSI_CHAR_REPEAT'].match(buf): - n = int(m.group(1)) - draw.append(('CSI_CHAR_REPEAT', x, y, n)) - x += n - - # Fe escape sequences - elif m := E['FE_CURSOR_ONEUP'].match(buf): - y = max(0, y - 1) # scroll not implemented - - # Fp escape sequences - elif m := E['FP_CURSOR_SAVE'].match(buf): - saved_dec = (x, y) - elif m := E['FP_CURSOR_RESTORE'].match(buf): - if saved_dec is not None: x, y = saved_dec - - # Operation +import enum +from logging import getLogger +from typing import Callable, Generator, List, Optional, Tuple, Union + +logger = getLogger(__name__) + + +# Based on https://bjh21.me.uk/all-escapes/all-escapes.txt +class AnsiOp(enum.Enum): + UNKNOWN = 0 + + # C0 Control Sequence + BEL = 0x10 + BS = enum.auto() + HT = enum.auto() + LF = enum.auto() + FF = enum.auto() + CR = enum.auto() + ESC = enum.auto() + + # Fe Escape Sequence + BPH = 0x20 # Break permitted here + NBH = enum.auto() # No break here + IND = enum.auto() # Index + NEL = enum.auto() # Next line + SSA = enum.auto() # Start of selected area + ESA = enum.auto() # End of selected area + HTS = enum.auto() # Character tabulation set + HTJ = enum.auto() # Character tabulation with justification + VTS = enum.auto() # Line tabulation set + PLD = enum.auto() # Partial line forward + PLU = enum.auto() # Partial line backward + RI = enum.auto() # Reverse line feed + SS2 = enum.auto() # Single-shift two + SS3 = enum.auto() # Single-shift three + DCS = enum.auto() # Device control string + PU1 = enum.auto() # Private use one + PU2 = enum.auto() # Private use two + STS = enum.auto() # Set transmit state + CCH = enum.auto() # Cancel character + MW = enum.auto() # Message waiting + SPA = enum.auto() # Start of guarded area + EPA = enum.auto() # End of guarded area + SOS = enum.auto() # Start of string + SCI = enum.auto() # Single character introducer + CSI = enum.auto() # Control sequence + + # Fp Private Control Functions + DECKPAM = 0x80 + + # CSI Sequence + ICH = 0x100 # Insert character + SBC = enum.auto() # Set border color + CUU = enum.auto() # Cursor up + SBP = enum.auto() # Set bell parameters + CUD = enum.auto() # Cursor down + SCR = enum.auto() # Set cursor parameters + CUF = enum.auto() # Cursor right + SBI = enum.auto() # Set background intensity + CUB = enum.auto() # Cursor left + SBB = enum.auto() # Set background blink bit + CNL = enum.auto() # Cursor next line + SNF = enum.auto() # Set normal foreground color + CPL = enum.auto() # Cursor preceding line + SNB = enum.auto() # Set normal background color + CHA = enum.auto() # Cursor character absolute + SRF = enum.auto() # Set reverse foreground color + CUP = enum.auto() # Cursor position + SRB = enum.auto() # Set reverse background color + CHT = enum.auto() # Cursor forward tabulation + ED = enum.auto() # Erase in page + SGF = enum.auto() # Set graphic foreground color + EL = enum.auto() # Erase in line + SGB = enum.auto() # Set graphic background color + IL = enum.auto() # Insert line + SEF = enum.auto() # Set emulator feature + DL = enum.auto() # Delete line + RAS = enum.auto() # Return attribute setting + EF = enum.auto() # Erase in field + EA = enum.auto() # Erase in area + DCH = enum.auto() # Delete character + SEE = enum.auto() # Select editing extent + CPR = enum.auto() # Active position report + SU = enum.auto() # Scroll up + SD = enum.auto() # Scroll down + NP = enum.auto() # Next page + PP = enum.auto() # Preceding page + CTC = enum.auto() # Cursor tabulation control + ECH = enum.auto() # Erase character + CVT = enum.auto() # Cursor line tabulation + CBT = enum.auto() # Cursor backward tabulation + SRS = enum.auto() # Start reversed string + PTX = enum.auto() # Parallel texts + SDS = enum.auto() # Start directed string + SIMD = enum.auto() # Select implicit movement direction + HPA = enum.auto() # Character position absolute + HPR = enum.auto() # Character position forward + REP = enum.auto() # Repeat + DA = enum.auto() # Device attributes + HSC = enum.auto() # Hide or show cursor + VPA = enum.auto() # Line position absolute + VPR = enum.auto() # Line position forward + HVP = enum.auto() # Character and line position + TBC = enum.auto() # Tabulation clear + PRC = enum.auto() # Print ROM character + SM = enum.auto() # Set mode + MC = enum.auto() # Media copy + HPB = enum.auto() # Character position backward + VPB = enum.auto() # Line position backward + RM = enum.auto() # Reset mode + CHC = enum.auto() # Clear and home cursor + SGR = enum.auto() # Select graphic rendition + SSM = enum.auto() # Set specific margin + DSR = enum.auto() # Device status report + DAQ = enum.auto() # Device area qualification + DECSSL = enum.auto() # Select set-up language + DECLL = enum.auto() # Load LEDs + DECSTBM = enum.auto() # Set top and bottom margins + RSM = enum.auto() # Reset margins + SCP = enum.auto() # Save cursor position + DECSLPP = enum.auto() # Set lines per physical page + RCP = enum.auto() # Reset cursor position + DECSVTS = enum.auto() # Set vertical tab stops + DECSHORP = enum.auto() # Set horizontal pitch + DGRTC = enum.auto() # Request terminal configuration + DECTST = enum.auto() # Invoke confidence test + SSW = enum.auto() # Screen switch + CAT = enum.auto() # Clear all tabs + + # SCS: Select character set + SCS_B = 0x200 # Default charset + SCS_0 = enum.auto() # DEC special charset + +class AnsiInstruction(object): + def __init__(self, + c0: AnsiOp, + code: Optional[AnsiOp]=None, + args: Optional[List[int]]=None): + self._c0 = c0 + self._code = code + self._args = args + + @property + def is_skip(self): + """Check if instruction can be skipped + + Returns: + bool: True if this instruction is not important for drawing screen + """ + return self._code in [ + AnsiOp.DECKPAM, + AnsiOp.DECSLPP, + AnsiOp.DECSTBM, + AnsiOp.SGR, + ] + + def __getitem__(self, i: int): + assert isinstance(i, int), "Slice must be integer" + if i < 0 or i >= len(self._args): + return None + else: + return self._args[i] + + def __eq__(self, other): + if isinstance(other, AnsiInstruction): + return self._c0 == other._c0 and \ + self._code == other._code and \ + self._args == other._args + + elif isinstance(other, AnsiOp): + return self._c0 == other or self._code == other + + else: + raise TypeError(f"Cannot compare AnsiInstruction and {type(other)}") + + def __neq__(self, other): + return not self.__eq__(other) + + def __str__(self): + return f'' + +class AnsiParser(object): + CTRL = [0x1b, 0x07, 0x08, 0x09, 0x0a, 0x0c, 0x0d] + ESC, BEL, BS, HT, LF, FF, CR = CTRL + + def __init__(self, + generator: Generator[bytes, None, None], + size: Tuple[int, int]=(0, 0), + pos: Tuple[int, int]=(0, 0)): + """ + Args: + generator: A generator which yields byte stream + size: Initial screen size (width, height) + pos: Initial cursor position (x, y) + """ + self._g = generator + self._buffer = b'' + self._width, self._height = size + self._x, self._y = pos + self._last_size = 0 + + @property + def buffer(self) -> bytes: + """Return contents of current buffering + """ + return self._buffer + + def _experimantal_warning(self, message: str): + logger.error(message) + logger.error("This feature is experimental and does not support some ANSI codes.\n" \ + "If you encounter this error, please create an issue here:\n" \ + "https://github.com/ptr-yudai/ptrlib/issues") + + def _decode_csi(self) -> Optional[AnsiInstruction]: + """Decode a CSI sequence + """ + c0, code = AnsiOp.ESC, AnsiOp.CSI + + # Parse parameters + mode_set, mode_q, mode_private = 0, 0, 0 + cur = 2 + args = [] + + while cur < len(self._buffer) and self._buffer[cur] in [ord('='), ord('?'), ord('>')]: + if self._buffer[cur] == ord('='): + mode_set = 1 + elif self._buffer[cur] == ord('?'): + mode_q = 1 + elif self._buffer[cur] == ord('>'): # TODO: Is this correct? + mode_private = 1 + else: + raise NotImplementedError("BUG: Unreachable path") + cur += 1 + + while True: + prev = cur + while cur < len(self._buffer) and 0x30 <= self._buffer[cur] <= 0x39: + cur += 1 + + if cur >= len(self._buffer): + self._last_size = len(self._buffer) + return None + + # NOTE: Common implementation seems to skip successive delimiters + if cur != prev: + args.append(int(self._buffer[prev:cur])) + + if self._buffer[cur] == ord(';'): + cur += 1 + else: + break + + # Check mnemonic + if self._buffer[cur] == ord('@'): + code = AnsiOp.ICH + default = (1,) + elif self._buffer[cur] == ord('A'): + code = [AnsiOp.CUU, AnsiOp.SBC][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('B'): + code = [AnsiOp.CUD, AnsiOp.SBP][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('C'): + code = [AnsiOp.CUF, AnsiOp.SCR][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('D'): + code = [AnsiOp.CUB, AnsiOp.SBI][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('E'): + code = [AnsiOp.CNL, AnsiOp.SBB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('F'): + code = [AnsiOp.CPL, AnsiOp.SNF][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('G'): + code = [AnsiOp.CHA, AnsiOp.SNB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('H'): + code = [AnsiOp.CUP, AnsiOp.SRF][mode_set] + default = [(1,1), ()][mode_set] + elif self._buffer[cur] == ord('I'): + # TODO: Support screen saver off + code = [AnsiOp.CHT, AnsiOp.SRB][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('J'): + # TODO: Support DECSED and screen saver on + code = [AnsiOp.ED, AnsiOp.SGF][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('K'): + # TODO: Support DECSEL + code = [AnsiOp.EL, AnsiOp.SGB][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('L'): + code = [AnsiOp.IL, AnsiOp.SEF][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('M'): + code = [AnsiOp.DL, AnsiOp.RAS][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('N'): + code, default = AnsiOp.EF, (0,) + default = (0,) + elif self._buffer[cur] == ord('O'): + code, default = AnsiOp.EA, (0,) + elif self._buffer[cur] == ord('P'): + code, default = AnsiOp.DCH, (1,) + elif self._buffer[cur] == ord('Q'): + code, default = AnsiOp.SEE, (0,) + elif self._buffer[cur] == ord('R'): + # TODO: Support DECXCPR + code, default = AnsiOp.CPR, (1, 1) + elif self._buffer[cur] == ord('S'): + code, default = AnsiOp.SU, (1,) + elif self._buffer[cur] == ord('T'): + # TODO: Support initiate hilite mouse tracking + code, default = AnsiOp.SD, (1,) + elif self._buffer[cur] == ord('U'): + code, default = AnsiOp.NP, (1,) + elif self._buffer[cur] == ord('V'): + code, default = AnsiOp.PP, (1,) + elif self._buffer[cur] == ord('W'): + # TODO: Support DECST8C + code, default = AnsiOp.CTC, (0,) + elif self._buffer[cur] == ord('X'): + code, default = AnsiOp.ECH, (1,) + elif self._buffer[cur] == ord('Y'): + code, default = AnsiOp.CVT, (1,) + elif self._buffer[cur] == ord('Z'): + code, default = AnsiOp.CBT, (1,) + elif self._buffer[cur] == ord('['): + # TODO: Support ignore next character + code, default = AnsiOp.SRS, (0,) + elif self._buffer[cur] == ord('\\'): + code, default = AnsiOp.PTX, (0,) + elif self._buffer[cur] == ord(']'): + # TODO: Support linux private sequences + code, default = AnsiOp.SDS, (0,) + elif self._buffer[cur] == ord('^'): + code, default = AnsiOp.SIMD, (0,) + elif self._buffer[cur] == ord('`'): + code, default = AnsiOp.HPA, (1,) + elif self._buffer[cur] == ord('a'): + code, default = AnsiOp.HPR, (1,) + elif self._buffer[cur] == ord('b'): + code, default = AnsiOp.REP, (1,) + elif self._buffer[cur] == ord('c'): + # NOTE: This operation has a lot of meanings + code = [AnsiOp.DA, AnsiOp.HSC][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('d'): + code, default = AnsiOp.VPA, (1,) + elif self._buffer[cur] == ord('e'): + code, default = AnsiOp.VPR, (1,) + elif self._buffer[cur] == ord('f'): + code, default = AnsiOp.HVP, (1, 1) + elif self._buffer[cur] == ord('g'): + # TODO: Support reset tabs + code = [AnsiOp.TBC, AnsiOp.PRC][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('h'): + code, default = AnsiOp.SM, () + elif self._buffer[cur] == ord('i'): + code, default = AnsiOp.MC, () + elif self._buffer[cur] == ord('j'): + code, default = AnsiOp.HPB, (1,) + elif self._buffer[cur] == ord('k'): + code, default = AnsiOp.VPB, (1,) + elif self._buffer[cur] == ord('l'): + # TODO: Support insert line up + code = [AnsiOp.RM, AnsiOp.CHC][mode_set] + default = [(1,), ()][mode_set] + elif self._buffer[cur] == ord('m'): + # TODO: Support delete line down + code = [AnsiOp.SGR, AnsiOp.SSM][mode_set] + default = [(0,), ()][mode_set] + elif self._buffer[cur] == ord('n'): + code, default = AnsiOp.DSR, (0,) + elif self._buffer[cur] == ord('o'): + code, default = AnsiOp.DAQ, (0,) + elif self._buffer[cur] == ord('p'): + code, default = AnsiOp.DECSSL, () + elif self._buffer[cur] == ord('q'): + code, default = AnsiOp.DECLL, () + elif self._buffer[cur] == ord('r'): + # TODO: Support CSR and SUNSCRL + code = [AnsiOp.DECSTBM, AnsiOp.RSM][mode_set] + default = [(), ()][mode_set] + elif self._buffer[cur] == ord('s'): + code, default = AnsiOp.SCP, () + elif self._buffer[cur] == ord('t'): + code, default = AnsiOp.DECSLPP, () + elif self._buffer[cur] == ord('u'): + code, default = AnsiOp.RCP, () + elif self._buffer[cur] == ord('v'): + code, default = AnsiOp.DECSVTS, () + elif self._buffer[cur] == ord('w'): + code, default = AnsiOp.DECSHORP, () + elif self._buffer[cur] == ord('x'): + code, default = AnsiOp.DGRTC, () + elif self._buffer[cur] == ord('y'): + code, default = AnsiOp.DECTST, () + elif self._buffer[cur] == ord('z'): + # TODO: Support + code = [AnsiOp.SSW, AnsiOp.CAT][mode_set] + default = [(), ()][mode_set] + else: + self._experimantal_warning(f"CSI not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError("Unknown CSI") + + if len(args) < len(default): + args = tuple(args + list(default[len(args):])) + + self._buffer = self._buffer[cur+1:] + return AnsiInstruction(c0, code, args) + + def _decode_esc(self) -> Optional[AnsiInstruction]: + """Decode an ESC sequence + """ + c0 = AnsiOp.ESC + code = AnsiOp.UNKNOWN + + cur = 1 + if len(self._buffer) <= cur: + self._last_size = len(self._buffer) + return None + + if self._buffer[cur] == ord('['): + cur += 1 + if self._buffer[cur] == ord('B'): + code = AnsiOp.BPH + elif self._buffer[cur] == ord('C'): + code = AnsiOp.NBH + elif self._buffer[cur] == ord('D'): + code = AnsiOp.IND + elif self._buffer[cur] == ord('E'): + code = AnsiOp.NEL + elif self._buffer[cur] == ord('F'): + code = AnsiOp.SSA + elif self._buffer[cur] == ord('G'): + code = AnsiOp.ESA + elif self._buffer[cur] == ord('H'): + code = AnsiOp.HTS + elif self._buffer[cur] == ord('I'): + code = AnsiOp.HTJ + elif self._buffer[cur] == ord('J'): + code = AnsiOp.VTS + elif self._buffer[cur] == ord('K'): + code = AnsiOp.PLD + elif self._buffer[cur] == ord('L'): + code = AnsiOp.PLU + elif self._buffer[cur] == ord('M'): + code = AnsiOp.RI + elif self._buffer[cur] == ord('N'): + code = AnsiOp.SS2 + elif self._buffer[cur] == ord('O'): + code = AnsiOp.SS3 + elif self._buffer[cur] == ord('P'): + code = AnsiOp.DCS + elif self._buffer[cur] == ord('Q'): + code = AnsiOp.PU1 + elif self._buffer[cur] == ord('R'): + code = AnsiOp.PU2 + elif self._buffer[cur] == ord('S'): + code = AnsiOp.STS + elif self._buffer[cur] == ord('T'): + code = AnsiOp.CCH + elif self._buffer[cur] == ord('U'): + code = AnsiOp.MW + elif self._buffer[cur] == ord('V'): + code = AnsiOp.SPA + elif self._buffer[cur] == ord('W'): + code = AnsiOp.EPA + elif self._buffer[cur] == ord('X'): + code = AnsiOp.SOS + elif self._buffer[cur] == ord('Z'): + code = AnsiOp.SCI + else: + return self._decode_csi() + + elif self._buffer[cur] == ord('('): + cur += 1 + if len(self._buffer) <= cur: + self._last_size = len(self._buffer) + return None + + if self._buffer[cur] == ord('B'): + code = AnsiOp.SCS_B + elif self._buffer[cur] == ord('0'): + code = AnsiOp.SCS_0 + else: + self._experimantal_warning(f"ESC not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError(f"Unknown ESC") + + elif self._buffer[cur] == ord('='): + code = AnsiOp.DECKPAM + else: - for k in ['CSI_ERASE_DISPLAY_FORWARD', - 'CSI_ERASE_DISPLAY_BACKWARD', - 'CSI_ERASE_DISPLAY_ALL', - 'CSI_ERASE_LINE_FORWARD', - 'CSI_ERASE_LINE_BACKWARD', - 'CSI_ERASE_LINE_ALL']: - if m := E[k].match(buf): - if k == 'CSI_ERASE_DISPLAY_ALL': - draw = [] - else: - draw.append((k, x, y, None)) - break - - # Otherwise draw text - if m: - buf = buf[m.end():] + self._experimantal_warning(f"ESC not implemented: {self._buffer[cur-2:cur+0x10]}") + raise NotImplementedError(f"Unknown ESC") + + self._buffer = self._buffer[cur+1:] + return AnsiInstruction(c0, code) + + def parse_block(self) -> Optional[Union[bytes, AnsiInstruction]]: + """Parse a block of ANSI escape sequence + + Returns: + AnsiInstruction: Instruction, or None if need more data + + Raises: + StopIteration: No more data to receive + """ + if len(self._buffer) <= self._last_size: + try: + self._buffer += next(self._g) + except StopIteration as e: + if len(self._buffer) == 0: + # All processed, end of input + raise e from None + + self._last_size = 0 + + # TODO: Support C1 control code + if self._buffer[0] not in AnsiParser.CTRL: + # Return until a control code appears + for i, c in enumerate(self._buffer): + if c in AnsiParser.CTRL: + data, self._buffer = self._buffer[:i], self._buffer[i:] + return data + + data, self._buffer = self._buffer, b'' + return data + + # Check C0 control sequence + if self._buffer[0] == AnsiParser.BEL: # BEL + instr = AnsiInstruction(AnsiOp.BEL) + elif self._buffer[0] == AnsiParser.BS: # BS + instr = AnsiInstruction(AnsiOp.BS) + elif self._buffer[0] == AnsiParser.HT: # HT + instr = AnsiInstruction(AnsiOp.HT) + elif self._buffer[0] == AnsiParser.LF: # LF + instr = AnsiInstruction(AnsiOp.LF) + elif self._buffer[0] == AnsiParser.FF: # FF + instr = AnsiInstruction(AnsiOp.FF) + elif self._buffer[0] == AnsiParser.CR: # CR + instr = AnsiInstruction(AnsiOp.CR) + else: + return self._decode_esc() + + self._buffer = self._buffer[1:] + return instr + + def _update_screen_size(self, screen): + if len(screen) == 0: + return + self._width = max(map(lambda pos: pos[0], screen.keys())) + 1 + self._height = max(map(lambda pos: pos[1], screen.keys())) + 1 + + def _special_char(self, charset: AnsiOp, c: int): + if charset == AnsiOp.SCS_B: + return c + + elif charset == AnsiOp.SCS_0: + if 0x5f <= c <= 0x7e: + return [0x20, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x6f, + 0x2b, 0x3f, 0x3f, 0x2b, 0x2b, 0x2b, 0x2b, 0x2b, + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x2b, 0x2b, 0x2b, + 0x2b, 0x7c, 0x3c, 0x3e, 0x6e, 0x3d, 0x66, 0x2e][c - 0x5f] + else: + return c + + else: + self._experimantal_warning(f"Character set not implemented: {charset}") + raise NotImplementedError("Unknown character set") + + def draw_screen(self, + returns: type=list, + stop: Optional[Callable[[AnsiInstruction], bool]]=None) -> list: + """Receive a screen + + Args: + returns: Either str or list + stop: Function to determine when to stop emulating instructions + """ + if stop is None: + # Default stop checker designed for ncurses games + stop = lambda instr: instr == AnsiOp.HTS + + # NOTE: These variables are global so that we can support + # successive draws in the future + self._width = self._height = 0 + + screen = {} + charset = AnsiOp.SCS_B + DEL = 0x20 # Empty + last_char = DEL + stop_recv = False + while not stop_recv: + instr = None + try: + while instr is None: + instr = self.parse_block() + except StopIteration: + break + + stop_recv = stop(instr) + + if isinstance(instr, bytes): + # TODO: Reverse order? + for c in instr: + screen[(self._x, self._y)] = self._special_char(charset, c) + self._x += 1 + last_char = c + + else: + if instr.is_skip: + continue + + elif instr == AnsiOp.SCS_B: # English mode + charset = AnsiOp.SCS_B + + elif instr == AnsiOp.BS: # Back space + self._x = max(0, self._x - 1) + stop_recv = True + + elif instr == AnsiOp.CHA: # Cursor character absolute + self._x = instr[0] - 1 + + elif instr == AnsiOp.SCS_0: # DEC special graphic + charset = AnsiOp.SCS_0 + + elif instr == AnsiOp.CR: # Carriage return + self._x, self._y = 0, self._y + 1 + + elif instr == AnsiOp.CUP: # Cursor position + self._x, self._y = instr[1] - 1, instr[0] - 1 + + elif instr == AnsiOp.ECH: # Erase character + for x in range(self._x, self._x + instr[0]): + screen[(x, self._y)] = DEL + + elif instr == AnsiOp.ED: # Erase in page + self._update_screen_size(screen) + if instr[0] == 0: + for y in range(self._y, self._height): + screen[(self._x, y)] = DEL + elif instr[0] == 1: + for y in range(self._y + 1): + screen[(self._x, y)] = DEL + elif instr[0] == 2: + for y in range(self._height): + screen[(self._x, y)] = DEL + + elif instr == AnsiOp.EL: # Erase in line + self._update_screen_size(screen) + if instr[0] == 0: + for x in range(self._x, self._width): + screen[(x, self._y)] = DEL + elif instr[0] == 1: + for x in range(self._x + 1): + screen[(x, self._y)] = DEL + elif instr[0] == 2: + for x in range(self._width): + screen[(x, self._y)] = DEL + + elif instr == AnsiOp.HTS: + self._x, self._y = 0, 0 + + elif instr == AnsiOp.LF: + self._x, self._y = 0, self._y + 1 + + elif instr == AnsiOp.REP: # Repeat + for x in range(self._x, self._x + instr[0]): + screen[(x, self._y)] = self._special_char(charset, last_char) + self._x += instr[0] + + elif instr == AnsiOp.RM: # Reset mode + pass # TODO: ? + + elif instr == AnsiOp.SM: # Set mode + pass # TODO: ? + + elif instr == AnsiOp.VPA: # Line position absolute + self._y = instr[0] - 1 + + else: + raise ValueError(f"Emulation not supported for instruction {instr}") + + self._update_screen_size(screen) + field = [[' ' for x in range(self._width)] + for y in range(self._height)] + for (x, y) in screen: + field[y][x] = chr(screen[(x, y)]) + + if returns == list: + return field else: - # TODO: skip ESC only? - raise NotImplementedError(f"Could not interpret code: {buf[:10]}") - - # Emualte drawing - screen = [[' ' for x in range(width)] for y in range(height)] - last_char = ' ' - for op, x, y, attr in draw: - if op == 'PUTCHAR': - last_char = chr(attr) - screen[y][x] = last_char - - elif op == 'CSI_CHAR_REPEAT': - for j in range(attr): - screen[y][x+j] = last_char - - elif op == 'CSI_ERASE_DISPLAY_FORWARD': - for j in range(x, width): - screen[y][j] = ' ' - for i in range(y+1, height): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_DISPLAY_BACKWARD': - for j in range(x): - screen[y][j] = ' ' - for i in range(y): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_DISPLAY_ALL': - for i in range(height): - for j in range(width): - screen[i][j] = ' ' - - elif op == 'CSI_ERASE_LINE_FORWARD': - for j in range(x, width): - screen[y][j] = ' ' - - elif op == 'CSI_ERASE_LINE_BACKWARD': - for j in range(x): - screen[y][j] = ' ' - - elif op == 'CSI_ERASE_LINE_ALL': - for j in range(width): - screen[y][j] = ' ' - - return screen + return '\n'.join(map(lambda line: ''.join(line), field)) diff --git a/ptrlib/binary/encoding/byteconv.py b/ptrlib/binary/encoding/byteconv.py index ac0f587..7613f90 100644 --- a/ptrlib/binary/encoding/byteconv.py +++ b/ptrlib/binary/encoding/byteconv.py @@ -9,6 +9,8 @@ def bytes2str(data: bytes) -> str: """ if isinstance(data, bytes): return ''.join(list(map(chr, data))) + elif isinstance(data, str): + return data # Fallback else: raise ValueError("{} given ('bytes' expected)".format(type(data))) @@ -20,6 +22,8 @@ def str2bytes(data: str) -> bytes: return bytes(list(map(ord, data))) except ValueError: return data.encode('utf-8') + elif isinstance(data, bytes): + return data # Fallback else: raise ValueError("{} given ('str' expected)".format(type(data))) diff --git a/ptrlib/binary/encoding/locale.py b/ptrlib/binary/encoding/char.py similarity index 100% rename from ptrlib/binary/encoding/locale.py rename to ptrlib/binary/encoding/char.py diff --git a/ptrlib/connection/__init__.py b/ptrlib/connection/__init__.py index 8706d43..beb4a72 100644 --- a/ptrlib/connection/__init__.py +++ b/ptrlib/connection/__init__.py @@ -1,4 +1,3 @@ -from .proc import * -from .sock import * +from .proc import Process, process +from .sock import Socket, remote from .ssh import * -from .winproc import * diff --git a/ptrlib/connection/proc.py b/ptrlib/connection/proc.py index bed3244..2a69c2a 100644 --- a/ptrlib/connection/proc.py +++ b/ptrlib/connection/proc.py @@ -1,15 +1,13 @@ -# coding: utf-8 -from logging import getLogger -from typing import Any, List, Mapping -from ptrlib.arch.linux.sig import * -from ptrlib.binary.encoding import * -from .tube import * -from .winproc import * -import errno -import select import os +import select import subprocess -import time +from logging import getLogger +from typing import List, Mapping, Optional, Union +from ptrlib.arch.linux.sig import signal_name +from ptrlib.binary.encoding import bytes2str +from .tube import Tube, tube_is_open +from .winproc import WinProcess + _is_windows = os.name == 'nt' if not _is_windows: @@ -19,208 +17,245 @@ logger = getLogger(__name__) - class UnixProcess(Tube): - def __init__( - self, - args: Union[Union[bytes, str], List[Union[bytes, str]]], - env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, - cwd: Optional[Union[bytes, str]]=None, - timeout: Optional[int]=None - ): - """Create a process - - Create a new process and make a pipe. + # + # Constructor + # + def __init__(self, + args: Union[bytes, str, List[Union[bytes, str]]], + env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, + cwd: Optional[Union[bytes, str]]=None, + shell: Optional[bool]=None, + raw: bool=False, + stdin : Optional[int]=None, + stdout: Optional[int]=None, + stderr: Optional[int]=None, + **kwargs): + """Create a UNIX process + + Create a UNIX process and make a pipe. Args: - args (list): The arguments to pass - env (list) : The environment variables + args : The arguments to pass + env : The environment variables + cwd : Working directory + shell : If true, `args` is a shell command + raw : Disable pty if this parameter is true + stdin : File descriptor of standard input + stdout : File descriptor of standard output + stderr : File descriptor of standard error Returns: - Process: ``Process`` instance. + Process: ``Process`` instance + + Examples: + ``` + p = Process("/bin/ls", cwd="/tmp") + p = Process(["wget", "www.example.com"], + stderr=subprocess.DEVNULL) + p = Process("cat /proc/self/maps", env={"LD_PRELOAD": "a.so"}) + ``` """ - assert not _is_windows - super().__init__() + assert not _is_windows, "UnixProcess cannot work on Windows" + assert isinstance(args, (str, bytes, list)), \ + "`args` must be either str, bytes, or list" + assert env is None or isinstance(env, dict), \ + "`env` must be a dictionary" + assert cwd is None or isinstance(cwd, (str, bytes)), \ + "`cwd` must be either str or bytes" + + # NOTE: We need to initialize _current_timeout before super constructor + # because it may call _settimeout_impl + self._current_timeout = 0 + super().__init__(**kwargs) + + # Guess shell mode based on args + if shell is None: + if isinstance(args, (str, bytes)): + args = [bytes2str(args)] + if ' ' in args[0]: + shell = True + logger.info("Detected whitespace in arguments: " \ + "`shell=True` enabled") + else: + shell = False + else: + shell = False - if isinstance(args, list): - self.args = args - self.filepath = args[0] else: - self.args = [args] - self.filepath = args - self.env = env - self.default_timeout = timeout - self.timeout = self.default_timeout - self.proc = None - self.returncode = None - - # Open pty on Unix - master, self.slave = pty.openpty() - tty.setraw(master) - tty.setraw(self.slave) - - # Create a new process + if isinstance(args, (str, bytes)): + args = [bytes2str(args)] + else: + args = list(map(bytes2str, args)) + + # Prepare stdio + master = self._slave = None + if not raw: + master, self._slave = pty.openpty() + tty.setraw(master) + tty.setraw(self._slave) + stdout = self._slave + + if stdin is None: stdin = subprocess.PIPE + if stdout is None: stdout = subprocess.PIPE + if stderr is None: stderr = subprocess.STDOUT + + # Open process + assert isinstance(shell, bool), "`shell` must be boolean" try: - self.proc = subprocess.Popen( - self.args, - cwd = cwd, - env = self.env, - shell = False, - stdout=self.slave, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE + self._proc = subprocess.Popen( + args, cwd=cwd, env=env, + shell=shell, + stdin=stdin, + stdout=stdout, + stderr=stderr, ) - except FileNotFoundError: - logger.warning("Executable not found: '{0}'".format(self.filepath)) - return + except FileNotFoundError as err: + logger.error(f"Could not execute {args[0]}") + raise err from None + + self._filepath = args[0] + + self._returncode = None # Duplicate master if master is not None: - self.proc.stdout = os.fdopen(os.dup(master), 'r+b', 0) + self._proc.stdout = os.fdopen(os.dup(master), 'r+b', 0) os.close(master) # Set in non-blocking mode - fd = self.proc.stdout.fileno() + fd = self._proc.stdout.fileno() fl = fcntl.fcntl(fd, fcntl.F_GETFL) fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) - logger.info("Successfully created new process (PID={})".format(self.proc.pid)) - - def _settimeout(self, timeout: Optional[Union[int, float]]): - if timeout is None: - self.timeout = self.default_timeout - elif timeout > 0: - self.timeout = timeout - - def _socket(self) -> Optional[Any]: - return self.proc - - def _poll(self) -> Optional[int]: - if self.proc is None: - return False - - # Check if the process exits - self.proc.poll() - returncode = self.proc.returncode - if returncode is not None and self.returncode is None: - self.returncode = returncode - name = signal_name(-returncode, detail=True) - if name: name = '--> ' + name - logger.error( - "Process '{}' (pid={}) stopped with exit code {} {}".format( - self.filepath, self.proc.pid, returncode, name - )) - return returncode - - def is_alive(self) -> bool: - """Check if the process is alive""" - return self._poll() is None - - def _can_recv(self) -> bool: - """Check if receivable""" - if self.proc is None: - return False - try: - r = select.select( - [self.proc.stdout], [], [], self.timeout - ) - if r == ([], [], []): - raise TimeoutError("Receive timeout", b'') - else: - # assert r == ([self.proc.stdout], [], []) - return True - except TimeoutError as e: - raise e from None - except select.error as v: - if v[0] == errno.EINTR: - return False - assert False, "unreachable" - - def _recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: + logger.info(f"Successfully created new process {str(self)}") + self._init_done = True + + # + # Properties + # + @property + def returncode(self) -> Optional[int]: + return self._returncode + + @property + def pid(self) -> int: + return self._proc.pid + + # + # Implementation of Tube methods + # + def _settimeout_impl(self, timeout: Union[int, float]): + self._current_timeout = timeout + + def _recv_impl(self, size: int) -> bytes: """Receive raw data - Receive raw data of maximum `size` bytes length through the pipe. + Receive raw data of maximum `size` bytes through the pipe. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size: Data size to receive Returns: bytes: The received data """ - self._settimeout(timeout) + if self._current_timeout == 0: + timeout = None + else: + timeout = self._current_timeout - if not self._can_recv(): - return b'' + ready, [], [] = select.select( + [self._proc.stdout.fileno()], [], [], timeout + ) + if len(ready) == 0: + raise TimeoutError("Timeout (_recv_impl)", b'') from None try: - data = self.proc.stdout.read(size) + data = self._proc.stdout.read(size) except subprocess.TimeoutExpired: - # TODO: Unreachable? - raise TimeoutError("Receive timeout", b'') from None + raise TimeoutError("Timeout (_recv_impl)", b'') from None - self._poll() # poll after received all data return data - def _send(self, data: Union[str, bytes]): + def _send_impl(self, data: bytes) -> int: """Send raw data - Send raw data through the socket - - Args: - data (bytes) : Data to send + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - self._poll() - if isinstance(data, str): - data = str2bytes(data) - elif not isinstance(data, bytes): - logger.warning("Expected 'str' or 'bytes' but {} given".format( - type(data) - )) - try: - self.proc.stdin.write(data) - self.proc.stdin.flush() - except IOError: - logger.warning("Broken pipe") + n_written = self._proc.stdin.write(data) + self._proc.stdin.flush() + return n_written + except IOError as err: + logger.error(f"Broken pipe: {str(self)}") + raise err from None + + def _shutdown_recv_impl(self): + """Close stdin + """ + self._proc.stdout.close() + self._proc.stderr.close() - def close(self): - """Close the socket + def _shutdown_send_impl(self): + """Close stdout + """ + self._proc.stdin.close() - Close the socket. - This method is called from the destructor. + def _close_impl(self): + """Close process """ - if self.proc: - os.close(self.slave) - self.proc.stdin.close() - self.proc.stdout.close() - if self.is_alive(): - self.proc.kill() - self.proc.wait() - logger.info("'{0}' (PID={1}) killed".format(self.filepath, self.proc.pid)) - self.proc = None - else: - logger.info("'{0}' (PID={1}) has already exited".format(self.filepath, self.proc.pid)) - self.proc = None + if self._is_alive_impl(): + self._proc.kill() + self._proc.wait() + logger.info(f"{str(self)} killed by `close`") + + if self._slave is not None: # PTY mode + os.close(self._slave) + self._slave = None + + if self._proc.stdin is not None: + self._proc.stdin.close() + if self._proc.stdout is not None: + self._proc.stdout.close() + if self._proc.stderr is not None: + self._proc.stderr.close() + + def _is_alive_impl(self) -> bool: + """Check if the process is alive""" + return self.poll() is None - def shutdown(self, target: Literal['send', 'recv']): - """Kill one connection + def __str__(self) -> str: + return f"'{self._filepath}' (PID={self._proc.pid})" - Close send/recv pipe. - Args: - target (str): Connection to close (`send` or `recv`) + # + # Custom method + # + def poll(self) -> Optional[int]: + """Check if the process has exited """ - if target in ['write', 'send', 'stdin']: - self.proc.stdin.close() + if self._proc.poll() is None: + return None - elif target in ['read', 'recv', 'stdout', 'stderr']: - self.proc.stdout.close() + if self._returncode is None: + # First time to detect process exit + self._returncode = self._proc.returncode + name = signal_name(-self._returncode, detail=True) + if name: + name = ' --> ' + name - else: - logger.error("You must specify `send` or `recv` as target.") + logger_func = logger.info if self._returncode == 0 else logger.error + logger_func(f"{str(self)} stopped with exit code " \ + f"{self._returncode}{name}") + + return self._returncode - def wait(self) -> int: + @tube_is_open + def wait(self, timeout: Optional[Union[int, float]]=None) -> int: """Wait until the process dies Wait until the process exits and get the status code. @@ -228,12 +263,8 @@ def wait(self) -> int: Returns: code (int): Status code of the process """ - while self.is_alive(): - time.sleep(0.1) - return self.returncode + return self._proc.wait(timeout) - def __del__(self): - self.close() Process = WinProcess if _is_windows else UnixProcess -process = Process # alias for the Process +process = Process # alias for the Process diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index c364a93..5386587 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -1,172 +1,255 @@ -# coding: utf-8 -from logging import getLogger - -from ptrlib.binary.encoding import * -from .tube import * +import errno +import select import socket +from logging import getLogger +from typing import Optional, Union +from ptrlib.binary.encoding import bytes2str +from .tube import Tube, tube_is_open logger = getLogger(__name__) class Socket(Tube): - def __init__(self, host: Union[str, bytes], port: Optional[int]=None, - timeout: Optional[Union[int, float]]=None, - ssl: bool=False, sni: Union[str, bool]=True): + # + # Constructor + # + def __init__(self, + host: Union[str, bytes], + port: Optional[int]=None, + ssl: bool=False, + sni: Union[str, bool]=True, + **kwargs): """Create a socket Create a new socket and establish a connection to the host. Args: - host (str): The host name or ip address of the server - port (int): The port number + host: Host name or ip address + port: Port number + ssl : Enable SSL/TLS + sni : SNI Returns: Socket: ``Socket`` instance. """ - super().__init__() + assert isinstance(host, (str, bytes)), \ + "`host` must be either str or bytes" - if isinstance(host, bytes): - host = bytes2str(host) + # NOTE: We need to initialize _current_timeout before super constructor + # because it may call _settimeout_impl + self._current_timeout = 0 + super().__init__(**kwargs) + # Interpret host name and port number + host = bytes2str(host) if port is None: host = host.strip() if host.startswith('nc '): _, a, b = host.split() - host, port = a, int(b) elif host.count(':') == 1: a, b = host.split(':') - host, port = a, int(b) elif host.count(' ') == 1: a, b = host.split() - host, port = a, int(b) else: - raise ValueError("Specify port number") + raise ValueError("Port number is not given") + host, port = a, int(b) + + else: + port = int(port) + + self._host = host + self._port = port - self.host = host - self.port = port - self.timeout = timeout # Create a new socket - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if ssl: import ssl as _ssl self.context = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT) self.context.check_hostname = False self.context.verify_mode = _ssl.CERT_NONE if sni is True: - self.sock = self.context.wrap_socket(self.sock) + self._sock = self.context.wrap_socket(self._sock) else: - self.sock = self.context.wrap_socket(self.sock, server_hostname=sni) + self._sock = self.context.wrap_socket(self._sock, server_hostname=sni) + # Establish a connection try: - self.sock.connect((self.host, self.port)) - logger.info("Successfully connected to {0}:{1}".format(self.host, self.port)) + self._sock.connect((self._host, self._port)) + logger.info(f"Successfully connected to {self._host}:{self._port}") + except ConnectionRefusedError as e: - err = "Connection to {0}:{1} refused".format(self.host, self.port) - logger.warning(err) + logger.error(f"Connection to {self._host}:{self._port} refused") raise e from None - def _settimeout(self, timeout: Optional[Union[int, float]]): - if timeout is None: - self.sock.settimeout(self.timeout) - elif timeout > 0: - self.sock.settimeout(timeout) + self._init_done = True + + # + # Implementation of Tube methods + # + def _settimeout_impl(self, + timeout: Union[int, float]): + """Set timeout - def _socket(self) -> Optional[socket.socket]: - return self.sock + Args: + timeout: Timeout in second + """ + self._current_timeout = timeout - def _recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: + def _recv_impl(self, size: int) -> bytes: """Receive raw data Receive raw data of maximum `size` bytes length through the socket. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size: Maximum data size to receive at once Returns: bytes: The received data + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - self._settimeout(timeout) + # NOTE: We cannot rely on the blocking behavior of `recv` + # because the socket might be non-blocking mode + # due to `_is_alive_impl` on multi-thread environment. + if self._current_timeout == 0: + timeout = None + else: + timeout = self._current_timeout + + ready, [], [] = select.select([self._sock], [], [], timeout) + if len(ready) == 0: + raise TimeoutError("Timeout (_recv_impl)", b'') from None try: - data = self.sock.recv(size) + data = self._sock.recv(size) + if len(data) == 0: + raise ConnectionResetError("Empty reply") from None + + except BlockingIOError: + # NOTE: This exception can occur if this method is called + # while `_is_alive_impl` is running in multi-thread. + # We make `_recv_impl` fail in this case. + return b'' + except socket.timeout: - raise TimeoutError("Receive timeout", b'') from None + raise TimeoutError("Timeout (_recv_impl)", b'') from None + except ConnectionAbortedError as e: - logger.warning("Connection aborted by the host") + logger.error("Connection aborted") raise e from None + except ConnectionResetError as e: - logger.warning("Connection reset by the host") + logger.error(f"Connection reset by {str(self)}") + raise e from None + + except OSError as e: + logger.error("OS Error") raise e from None return data - def _send(self, data: Union[str, bytes]): + def _send_impl(self, data: bytes) -> int: """Send raw data - Send raw data through the socket - - Args: - data (bytes) : Data to send + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - if isinstance(data, str): - data = str2bytes(data) - elif not isinstance(data, bytes): - logger.warning("Expected 'str' or 'bytes' but {} given".format( - type(data) - )) - try: - self.sock.send(data) + return self._sock.send(data) + except BrokenPipeError as e: - logger.warning("Broken pipe") + logger.error("Broken pipe") raise e from None + except ConnectionAbortedError as e: - logger.warning("Connection aborted by the host") + logger.error("Connection aborted") + raise e from None + + except ConnectionResetError as e: + logger.error(f"Connection reset by {str(self)}") + raise e from None + + except OSError as e: + logger.error("OS Error") raise e from None - def close(self): - """Close the socket + def _close_impl(self): + """Close socket + """ + self._sock.close() + logger.info(f"Connection to {str(self)} closed") - Close the socket. - This method is called from the destructor. + def _is_alive_impl(self) -> bool: + """Check if socket is alive """ - if self.sock: - self.sock.close() - self.sock = None - logger.info("Connection to {0}:{1} closed".format(self.host, self.port)) + try: + # Save timeout value since non-blocking mode will clear it + timeout = self._sock.gettimeout() + self._sock.setblocking(False) - def shutdown(self, target: Literal['send', 'recv']): - """Kill one connection + # Connection is closed if recv returns empty buffer + ret = len(self._sock.recv(1, socket.MSG_PEEK)) == 1 - Close send/recv socket. + except BlockingIOError as e: + ret = True - Args: - target (str): Connection to close (`send` or `recv`) + except (ConnectionResetError, socket.timeout): + ret = False + + finally: + self._sock.setblocking(True) + self._settimeout_impl(timeout) + + return ret + + def _shutdown_recv_impl(self): + """Close read """ - if target in ['write', 'send', 'stdin']: - self.sock.shutdown(socket.SHUT_WR) + self._sock.shutdown(socket.SHUT_RD) - elif target in ['read', 'recv', 'stdout', 'stderr']: - self.sock.shutdown(socket.SHUT_RD) + def _shutdown_send_impl(self): + """Close write + """ + self._sock.shutdown(socket.SHUT_WR) - else: - logger.error("You must specify `send` or `recv` as target.") + def __str__(self) -> str: + return f"{self._host}:{self._port}" - def is_alive(self, timeout: Optional[Union[int, float]]=None) -> bool: - try: - self._settimeout(timeout) - data = self.sock.recv(1, socket.MSG_PEEK) - return True - except BlockingIOError: - return False - except ConnectionResetError: - return False - except socket.timeout: - return False - def __del__(self): - self.close() + # + # Custom methods + # + @tube_is_open + def set_keepalive(self, + keep_idle: Optional[Union[int, float]]=None, + keep_interval: Optional[Union[int, float]]=None, + keep_count: Optional[Union[int, float]]=None): + """Set TCP keep-alive mode + + Send a keep-alive ping once every `keep_interval` seconds if activates + after `keep_idle` seconds of idleness, and closes the connection + after `keep_count` failed ping. + + Args: + keep_idle : Maximum duration to wait before sending keep-alive ping in second + keep_interval: Interval to send keep-alive ping in second + keep_count : Maximum number of failed attempts + """ + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if keep_idle is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keep_idle) + if keep_interval is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keep_interval) + if keep_count is not None: + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keep_count) + -# alias -remote = Socket +remote = Socket # alias diff --git a/ptrlib/connection/ssh.py b/ptrlib/connection/ssh.py index 8eb7cdd..3749087 100644 --- a/ptrlib/connection/ssh.py +++ b/ptrlib/connection/ssh.py @@ -1,20 +1,21 @@ -# coding: utf-8 import shlex import os from ptrlib.binary.encoding import * from ptrlib.arch.common import which from .proc import * -if os.name == 'nt': - _is_windows = True -else: - _is_windows = False +_is_windows = os.name == 'nt' -def SSH(host: str, port: int, username: str, - password: Optional[str]=None, identity: Optional[str]=None, - ssh_path: Optional[str]=None, expect_path: Optional[str]=None, - option: str='', command: str=''): +def SSH(host: str, + port: int, + username: str, + password: Optional[str]=None, + identity: Optional[str]=None, + ssh_path: Optional[str]=None, + expect_path: Optional[str]=None, + option: str='', + command: str=''): """Create an SSH shell Create a new process to connect to SSH server diff --git a/ptrlib/connection/tube.py b/ptrlib/connection/tube.py index ad20b65..36804f9 100644 --- a/ptrlib/connection/tube.py +++ b/ptrlib/connection/tube.py @@ -1,117 +1,213 @@ -# coding: utf-8 -import subprocess -from typing import Any, Optional, List, Tuple, Union, overload -try: - from typing import Literal -except: - from typing_extensions import Literal -from ptrlib.binary.encoding import * -from ptrlib.console.color import Color -from abc import ABCMeta, abstractmethod +import abc import re import sys import threading -import time from logging import getLogger +from typing import Callable, List, Literal, Optional, Tuple, Union +from ptrlib.binary.encoding import bytes2str, str2bytes, bytes2hex, bytes2utf8, hexdump, AnsiParser, AnsiInstruction +from ptrlib.console.color import Color logger = getLogger(__name__) +def tube_is_open(method): + """Ensure that connection is not *explicitly* closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_closed: + raise BrokenPipeError("Connection has already been closed by `close`") + return method(self, *args, **kwargs) + return decorator + +def tube_is_send_open(method): + """Ensure that sender connection is not explicitly closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_send_closed: + raise BrokenPipeError("Connection has already been closed by `shutdown`") + return method(self, *args, **kwargs) + return decorator + +def tube_is_recv_open(method): + """Ensure that receiver connection is not explicitly closed + """ + def decorator(self, *args, **kwargs): + assert isinstance(self, Tube), "Invalid usage of decorator" + if self._is_recv_closed: + raise BrokenPipeError("Connection has already been closed by `shutdown`") + return method(self, *args, **kwargs) + return decorator + + +class Tube(metaclass=abc.ABCMeta): + """Abstract class for streaming data + + A child class must implement the following methods: + + - "_settimeout_impl" + - "_recv_impl" + - "_send_impl" + - "_close_impl" + - "_is_alive_impl + - "_shutdown_recv_impl" + - "_shutdown_send_impl" + """ + def __new__(cls, *args, **kwargs): + cls._settimeout_impl = tube_is_open(cls._settimeout_impl) + cls._recv_impl = tube_is_recv_open(tube_is_open(cls._recv_impl)) + cls._send_impl = tube_is_send_open(tube_is_open(cls._send_impl)) + cls._close_impl = tube_is_open(cls._close_impl) + cls._is_alive_impl = tube_is_open(cls._is_alive_impl) + cls._shutdown_recv_impl = tube_is_recv_open(cls._shutdown_recv_impl) + cls._shutdown_send_impl = tube_is_send_open(cls._shutdown_send_impl) + return super().__new__(cls) + + # + # Constructor + # + def __init__(self, + timeout: Optional[Union[int, float]]=None, + debug: bool=False): + """Base constructor -class Tube(metaclass=ABCMeta): - def __init__(self): - self.buf = b'' - self.debug = False - - @abstractmethod - def _settimeout(self, timeout: Optional[Union[int, float]]): + Args: + timeout (float): Default timeout + """ + self._buffer = b'' + self._debug = debug + + self._is_closed = False + self._is_send_closed = False + self._is_recv_closed = False + + self._default_timeout = timeout + self.settimeout() + + # + # Properties + # + @property + def debug(self): + return self._debug + + @debug.setter + def debug(self, is_debug): + self._debug = bool(is_debug) + + # + # Methods + # + def settimeout(self, timeout: Optional[Union[int, float]]=None): """Set timeout - + Args: - timeout (float): Timeout (None: Set to default / -1: No change / x>0: Set timeout to x seconds) + timeout (float): Timeout in second + + Note: + Set timeout to None in order to set the default timeout) + + Examples: + ``` + p = Socket("0.0.0.0", 1337, timeout=3) + # ... + p.settimeout(5) # Timeout is set to 5 + # ... + p.settimeout() # Timeout is set to 3 + ``` """ - pass + assert timeout is None or (isinstance(timeout, (int, float)) and timeout >= 0), \ + "`timeout` must be positive and either int or float" - @abstractmethod - def _recv(self, size: int, timeout: Union[int, float]) -> Optional[bytes]: - """Receive raw data + if timeout is None: + if self._default_timeout is not None: + self._settimeout_impl(self._default_timeout) + else: + self._settimeout_impl(timeout) - Receive raw data of maximum `size` bytes length through the socket. + def recv(self, + size: int=4096, + timeout: Optional[Union[int, float]]=None) -> bytes: + """Receive data with buffering + + Receive raw data of at most `size` bytes. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size : Size to receive (Use `recvonce` to read exactly `size` bytes) + timeout: Timeout in second Returns: - bytes: The received data - """ - pass - - def unget(self, data: Union[str, bytes]): - """Revert data to socket + bytes: Received data - Return data to socket. + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error - Args: - data (bytes): Data to return + Examples: + ``` + tube.recv(4) + try: + tube.recv(timeout=3.14) + except TimeoutError: + pass + ``` """ - if isinstance(data, str): - data = str2bytes(data) - self.buf = data + self.buf + assert size is None or (isinstance(size, int) and size >= 0), \ + "`size` must be a positive integer" - def recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data with buffering + # NOTE: We always return buffer if it's not empty + # This is because we do not know how many bytes we can read. + if len(self._buffer): + data, self._buffer = self._buffer[:size], self._buffer[size:] + return data - Receive raw data of maximum `size` bytes length through the socket. + if timeout is not None: + self.settimeout(timeout) - Args: - size (int): The data size to receive (Use `recvonce` - if you want to read exactly `size` bytes) - timeout (int): Timeout (in second) + try: + data = self._recv_impl(size - len(self._buffer)) + if self._debug and len(data) > 0: + logger.info(f"Received {hex(len(data))} ({len(data)}) bytes:") + hexdump(data, prefix=" " + Color.CYAN, postfix=Color.END) - Returns: - bytes: The received data - """ - if size <= 0: - raise ValueError("`size` must be larger than 0") + self._buffer += data - elif len(self.buf) == 0: - self._settimeout(timeout) - try: - data = self._recv(size, timeout=-1) - except TimeoutError as err: - raise TimeoutError("`recv` timeout", b'') + except TimeoutError as err: + data = self._buffer + err.args[1] + self._buffer = b'' + raise TimeoutError("Timeout (recv)", data) - self.buf += data - if self.debug: - logger.info(f"Received {hex(len(data))} ({len(data)}) bytes:") - hexdump(data, prefix=" " + Color.CYAN, postfix=Color.END) + finally: + if timeout is not None: + # Reset timeout to default value + self.settimeout() - # We don't check size > len(self.buf) because Python handles it - data, self.buf = self.buf[:size], self.buf[size:] + data, self._buffer = self._buffer[:size], self._buffer[size:] return data - def recvonce(self, size: int, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data with buffering + def recvonce(self, + size: int, + timeout: Optional[Union[int, float]]=None) -> bytes: + """Receive raw data of exact size with buffering - Receive raw data of size `size` bytes length through the socket. + Receive raw data of exactly `size` bytes. Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + size : Data size to receive + timeout: Timeout in second Returns: - bytes: The received data + bytes: Received data """ - self._settimeout(timeout) data = b'' - timer_start = time.time() while len(data) < size: try: - data += self.recv(size - len(data), timeout=-1) + data += self.recv(size - len(data), timeout) except TimeoutError as err: - raise TimeoutError("`recvonce` timeout", data + err.args[1]) - time.sleep(0.01) + raise TimeoutError("Timeout (recvonce)", data + err.args[1]) if len(data) > size: self.unget(data[size:]) @@ -122,207 +218,274 @@ def recvuntil(self, size: int=4096, timeout: Optional[Union[int, float]]=None, drop: bool=False, - lookahead: bool=False, - sleep_time: float=0.01) -> bytes: + lookahead: bool=False) -> bytes: """Receive raw data until `delim` comes Args: - delim (bytes): The delimiter bytes - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool): Discard delimiter or not - lookahead (bool): Unget delimiter to buffer or not - sleep_time (float): Sleep time after receiving data + delim : The delimiter bytes + size : The data size to receive at once + timeout : Timeout in second + drop : Discard delimiter or not + lookahead: Unget delimiter to buffer or not Returns: bytes: Received data + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + + Examples: + ``` + echo.sendline("abc123def") + echo.recvuntil("123") # abc123 + + echo.sendline("abc123def") + echo.recvuntil("123", drop=True) # abc + + echo.sendline("abc123def") + echo.recvuntil("123", lookahead=True) # abc123 + echo.recvonce(6) # 123def + ``` """ - # Validate and normalize delimiter - if isinstance(delim, bytes): - delim = [delim] - elif isinstance(delim, str): - delim = [str2bytes(delim)] - elif isinstance(delim, list): - for i, t in enumerate(delim): - if isinstance(t, str): - delim[i] = str2bytes(t) - elif not isinstance(t, bytes): - raise ValueError(f"Delimiter must be either string or bytes: {t}") + assert isinstance(delim, (str, bytes, list)), \ + "`delim` must be either str, bytes, or list" + + # Preprocess + if isinstance(delim, list): + for i, d in enumerate(delim): + assert isinstance(d, (str, bytes)), \ + f"`delim[{i}]` must be either str or bytes" + delim[i] = str2bytes(delim[i]) else: - raise ValueError(f"Delimiter must be either string, bytes, or list: {t}") + delim = [str2bytes(delim)] - self._settimeout(timeout) - data = b'' - timer_start = time.time() + if any(map(lambda d: len(d) == 0, delim)): + return b'' # Empty delimiter - found = False - token = None + # Iterate until we find one of the delimiters + found_delim = None + prev_len = 0 + data = b'' while True: try: - data += self.recv(size, timeout=-1) + data += self.recv(size, timeout) except TimeoutError as err: - raise TimeoutError("`recvuntil` timeout", data + err.args[1]) - - for t in delim: - if t in data: - found = True - token = t + raise TimeoutError("Timeout (recvuntil)", data + err.args[1]) + except Exception as err: + err.args = (err.args[0], data) + raise err from None + + for d in delim: + if d in data[max(0, prev_len-len(d)):]: + found_delim = d break - - if found: + if found_delim is not None: break - if sleep_time: - time.sleep(sleep_time) - found_pos = data.find(token) - result_len = found_pos if drop else found_pos + len(token) - consumed_len = found_pos if lookahead else found_pos + len(token) - self.unget(data[consumed_len:]) - return data[:result_len] + prev_len = len(data) + + i = data.find(found_delim) + j = i + len(found_delim) + if not drop: + i = j + + ret, data = data[:i], data[j:] + self.unget(data) + if lookahead: + self.unget(found_delim) + + return ret def recvline(self, size: int=4096, timeout: Optional[Union[int, float]]=None, - drop: bool=True) -> bytes: + drop: bool=True, + lookahead: bool=False) -> bytes: """Receive a line of data Args: - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool) : Discard delimiter or not + size : The data size to receive at once + timeout : Timeout (in second) + drop : Discard trailing newlines or not + lookahead: Unget trailing newline to buffer or not Returns: bytes: Received data """ - line = self.recvuntil(b'\n', size, timeout) - if drop: - return line.rstrip() - return line + try: + line = self.recvuntil(b'\n', size, timeout, lookahead=lookahead) + except TimeoutError as err: + raise TimeoutError("Timeout (recvline)", err.args[1]) + + return line.rstrip() if drop else line def recvlineafter(self, delim: Union[str, bytes], size: int=4096, timeout: Optional[Union[int, float]]=None, - drop: bool=True) -> bytes: + drop: bool=True, + lookahead: bool=False) -> bytes: """Receive a line of data after receiving `delim` Args: - delim (bytes): The delimiter bytes - size (int) : The data size to receive at once - timeout (int): Timeout (in second) - drop (bool) : Discard delimiter or not + delim : The delimiter bytes + size : The data size to receive at once + timeout : Timeout (in second) + drop : Discard trailing newline or not + lookahead: Unget trailing newline to buffer or not Returns: bytes: Received data - """ - self.recvuntil(delim, size, timeout) - return self.recvline(size, timeout, drop) - # TODO: proper typing - @overload - def recvregex(self, regex: Union[str, bytes], size: int=4096, discard: Literal[True]=True, timeout: Optional[Union[int, float]]=None) -> bytes: ... + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error + """ + try: + self.recvuntil(delim, size, timeout) + except TimeoutError as err: + # NOTE: We do not set received value here + raise TimeoutError("Timeout (recvlineafter)", b'') - @overload - def recvregex(self, regex: Union[str, bytes], size: int=4096, discard: Literal[False]=False, timeout: Optional[Union[int, float]]=None) -> Tuple[bytes, bytes]: ... + try: + return self.recvline(size, timeout, drop, lookahead) + except TimeoutError as err: + raise TimeoutError("Timeout (recvlineafter)", err.args[1]) def recvregex(self, - regex: Union[str, bytes], + regex: Union[str, bytes, re.Pattern], size: int=4096, - discard: bool=True, - timeout: Optional[Union[int, float]]=None) -> Union[bytes, Tuple[bytes, bytes]]: + timeout: Optional[Union[int, float]]=None) -> Union[bytes, Tuple[bytes, ...]]: """Receive until a pattern comes Receive data until a specified regex pattern matches. Args: - regex (bytes) : Regex - size (int) : Size to read at once - discard (bool): Discard received bytes or not - timeout (int) : Timeout (in second) + regex : Regular expression + size : Size to read at once + timeout: Timeout in second Returns: tuple: If the given regex has multiple patterns to find, it returns all matches. Otherwise, it returns the - match string. If discard is false, it also returns - all data received so far along with the matches. + matched string. + + Raises: + ConnectionAbortedError: Connection is aborted by process + ConnectionResetError: Connection is closed by peer + TimeoutError: Timeout exceeded + OSError: System error """ - if not isinstance(regex, bytes): - regex = str2bytes(regex) + assert isinstance(regex, (str, bytes, re.Pattern)), \ + "`regex` must be either str, bytes, or re.Pattern" - p = re.compile(regex) - data = b'' + if isinstance(regex, str): + regex = re.compile(str2bytes(regex)) - self._settimeout(timeout) - r = None - while r is None: - data += self.recv(size, timeout=-1) - r = p.search(data) + data = b'' + match = None + while match is None: + try: + data += self.recv(size, timeout) + except TimeoutError as err: + raise TimeoutError("Timeout (recvregex)", data + err.args[1]) + match = regex.search(data) - pos = r.end() - self.unget(data[pos:]) + self.unget(data[match.end():]) - group = r.group() - groups = r.groups() - if groups: - if discard: - return groups - else: - return groups, data[:pos] + if match.groups(): + return match.groups() else: - if discard: - return group - else: - return group, data[:pos] + return match.group() - def recvscreen(self, delim: Optional[bytes]=b'\x1b[H', - returns: Optional[type]=str, - timeout: Optional[Union[int, float]]=None, - timeout2: Optional[Union[int, float]]=1): + def recvscreen(self, + returns: type=str, + stop: Optional[Callable[[AnsiInstruction], bool]]=None, + timeout: Union[int, float]=1.0): """Receive a screen - Receive a screen drawn by ncurses + Receive a screen drawn by ncurses (ANSI escape sequence) Args: - delim (bytes) : Refresh sequence - returns (type): Return value as string or list - timeout (int) : Timeout to receive the first delimiter - timeout2 (int): Timeout to receive the second delimiter + returns: Either str or list + stop: Function to determine when to stop emulating instructions + timeout: Timeout until stopping recv Returns: str: Rectangle string drawing the screen """ - self.recvuntil(delim, timeout=timeout) - try: - buf = self.recvuntil(delim, drop=True, lookahead=True, - timeout=timeout2) - except TimeoutError as err: - buf = err.args[1] - screen = draw_ansi(buf) - - if returns == list: - return screen - elif returns == str: - return '\n'.join(map(lambda row: ''.join(row), screen)) - elif returns == bytes: - return b'\n'.join(map(lambda row: bytes(row), screen)) - else: - raise TypeError("`returns` must be either list, str, or bytes") + assert returns in [list, str, bytes], \ + "`returns` must be either list or str" - @abstractmethod - def _send(self, data: bytes): - pass + def _ansi_stream(self): + """Generator for recvscreen + """ + while True: + try: + yield self.recv(timeout=timeout) + except TimeoutError as e: + self.unget(e.args[1]) + break + + ansi = AnsiParser(_ansi_stream(self)) + scr = ansi.draw_screen(returns, stop) + self.unget(ansi.buffer) + return scr + + def send(self, data: Union[str, bytes]) -> int: + """Send raw data + + Send as much data as possible. + + Args: + data: Data to send + + Returns: + int: Length of sent data + + Note: + It is NOT ensured that all data is sent. + Use `sendonce` to make sure the whole data is sent. + + Examples: + ``` + tube.send("Hello") + tube.send(b"\xde\xad\xbe\xef") + ``` + """ + assert isinstance(data, (str, bytes)), "`data` must be either str or bytes" - def send(self, data: bytes): - self._send(data) + size = self._send_impl(str2bytes(data)) if self.debug: - logger.info(f"Sent {hex(len(data))} ({len(data)}) bytes:") - hexdump(data, prefix=Color.YELLOW, postfix=Color.END) + logger.info(f"Sent {hex(size)} ({size}) bytes:") + hexdump(data[:size], prefix=Color.YELLOW, postfix=Color.END) - @abstractmethod - def _socket(self) -> Optional[Any]: - pass + return size + + def sendall(self, data: Union[str, bytes]): + """Send the whole data - def sendline(self, data: Union[str, bytes], timeout: Optional[Union[int, float]]=None): + Send the whole data. + This method will never return until it finishes sending + the whole data, unlike `send`. + + Args: + data: Data to send + """ + to_send = len(data) + while to_send > 0: + sent = self.send(data) + data = data[sent:] + to_send -= sent + + def sendline(self, + data: Union[int, float, str, bytes], + timeout: Optional[Union[int, float]]=None): """Send a line Send a line of data. @@ -331,37 +494,56 @@ def sendline(self, data: Union[str, bytes], timeout: Optional[Union[int, float]] data (bytes) : Data to send timeout (int): Timeout (in second) """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): + assert isinstance(data, (int, float, str, bytes)), \ + "`data` must be int, float, str, or bytes" + + if isinstance(data, (int, float)): data = str(data).encode() + else: + data = str2bytes(data) self.send(data + b'\n') - def sendafter(self, delim: Union[str, bytes], data: Union[str, bytes, int], timeout: Optional[Union[int, float]]=None): + def sendafter(self, + delim: Union[str, bytes, List[Union[str, bytes]]], + data: Union[int, float, str, bytes], + size: int=4096, + timeout: Optional[Union[int, float]]=None, + drop: bool=False, + lookahead: bool=False) -> bytes: """Send raw data after a delimiter Send raw data after `delim` is received. Args: - delim (bytes): The delimiter - data (bytes) : Data to send - timeout (int): Timeout (in second) + delim : The delimiter + data : Data to send + size : Data size to receive at once + timeout : Timeout in second + drop : Discard delimiter or not + lookahead: Unget delimiter to buffer or not Returns: bytes: Received bytes before `delim` comes. - """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): - data = str(data).encode() - recv_data = self.recvuntil(delim, timeout=timeout) + Examples: + ``` + tube.sendafter("> ", p32(len(data)) + data) + tube.sendafter("command: ", 1) # b"1" is sent + ``` + """ + recv_data = self.recvuntil(delim, size, timeout, drop, lookahead) self.send(data) return recv_data - def sendlineafter(self, delim: Union[str, bytes], data: Union[str, bytes, int], timeout: Optional[Union[int, float]]=None) -> bytes: + def sendlineafter(self, + delim: Union[str, bytes], + data: Union[str, bytes, int], + size: int=4096, + timeout: Optional[Union[int, float]]=None, + drop: bool=False, + lookahead: bool=False) -> bytes: """Send raw data after a delimiter Send raw data with newline after `delim` is received. @@ -374,12 +556,7 @@ def sendlineafter(self, delim: Union[str, bytes], data: Union[str, bytes, int], Returns: bytes: Received bytes before `delim` comes. """ - if isinstance(data, str): - data = str2bytes(data) - elif isinstance(data, int): - data = str(data).encode() - - recv_data = self.recvuntil(delim, timeout=timeout) + recv_data = self.recvuntil(delim, size, timeout, drop, lookahead) self.sendline(data, timeout=timeout) return recv_data @@ -390,7 +567,7 @@ def sendctrl(self, name: str): Send control key given its name Args: - name (str): Name of the control key to send + name: Name of the control key to send """ if name.lower() in ['w', 'up']: self.send(b'\x1bOA') @@ -409,97 +586,253 @@ def sendctrl(self, name: str): else: raise ValueError(f"Invalid control key name: {name}") - def sh(self, timeout: Optional[Union[int, float]]=None): + def sh(self, + prompt: str="[ptrlib]$ ", + raw: bool=False): """Alias for interactive + + Args: + prompt: Prompt string to show on input + raw : Escape non-printable characters or not """ - self.interactive(timeout) + self.interactive(prompt, raw) - def interactive(self, timeout: Optional[Union[int, float]]=None): + def interactive(self, + prompt: str="[ptrlib]$ ", + raw: bool=False): """Interactive mode + + Args: + prompt: Prompt string to show on input + raw : Escape non-printable characters or not """ - def thread_recv(): - prev_leftover = None + prompt = f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}" + + def pretty_print_hex(c: str): + sys.stdout.write(f'{Color.RED}\\x{ord(c):02x}{Color.END}') + + def pretty_print(data: bytes, prev: bytes=b''): + """Print data in a human-friendly way + """ + leftover = b'' + + if raw: + sys.stdout.write(bytes2str(data)) + + else: + utf8str, leftover, marker = bytes2utf8(data) + if len(utf8str) == 0 and prev == leftover: + utf8str = f'{Color.RED}{bytes2hex(leftover)}{Color.END}' + leftover = b'' + + for c, t in zip(utf8str, marker): + if t: + if 0x7f <= ord(c) < 0x100: + pretty_print_hex(c) + elif ord(c) in [0x00]: # TODO: What is printable? + pretty_print_hex(c) + else: + sys.stdout.write(c) + else: + pretty_print_hex(c) + + sys.stdout.flush() + return leftover + + def thread_recv(flag: threading.Event): + """Receive data from tube and print to stdout + """ + leftover = b'' while not flag.isSet(): try: - data = self.recv(size=4096, timeout=0.1) - if data is not None: - utf8str, leftover, marker = bytes2utf8(data) - if len(utf8str) == 0 and prev_leftover == leftover: - # Print raw hex string with color - # if the data is invalid as UTF-8 - utf8str = '{red}{hexstr}{end}'.format( - red=Color.RED, - hexstr=bytes2hex(leftover), - end=Color.END - ) - leftover = None - - for c, t in zip(utf8str, marker): - if t == True: - sys.stdout.write(c) - else: - sys.stdout.write('{red}{hexstr}{end}'.format( - red=Color.RED, - hexstr=str2hex(c), - end=Color.END - )) - sys.stdout.flush() - prev_leftover = leftover - if leftover is not None: - self.unget(leftover) + sys.stdout.write(prompt) + sys.stdout.flush() + data = self.recv() + leftover = pretty_print(data, leftover) + + if not self.is_alive(): + logger.error(f"Connection closed by {str(self)}") + flag.set() + except TimeoutError: - pass + pass # NOTE: We can ignore args since recv will never buffer except EOFError: logger.error("Receiver EOF") break except ConnectionAbortedError: logger.error("Receiver EOF") break - time.sleep(0.1) - flag = threading.Event() - th = threading.Thread(target=thread_recv) - th.setDaemon(True) - th.start() - - try: + def thread_send(flag: threading.Event): + """Read user input and send it to tube + """ + #sys.stdout.write(f"{Color.BOLD}{Color.BLUE}{prompt}{Color.END}") + #sys.stdout.flush() while not flag.isSet(): - data = input("{bold}{blue}[ptrlib]${end} ".format( - bold=Color.BOLD, blue=Color.BLUE, end=Color.END - )) - if self._socket() is None: - logger.error("Connection already closed") - break - if data is None: + try: + self.send(sys.stdin.readline()) + except (ConnectionResetError, ConnectionAbortedError, OSError): flag.set() - else: - try: - self.sendline(data) - except ConnectionAbortedError: - logger.error("Sender EOF") - break - time.sleep(0.1) + + flag = threading.Event() + th_recv = threading.Thread(target=thread_recv, args=(flag,)) + th_send = threading.Thread(target=thread_send, args=(flag,)) + th_recv.start() + th_send.start() + try: + th_recv.join() + th_send.join() except KeyboardInterrupt: + logger.warning("Intterupted by user") + sys.stdin.close() flag.set() - while th.is_alive(): - th.join(timeout = 0.1) - time.sleep(0.1) + def close(self): + """Close this connection + + Note: + This method can only be called once. + """ + self._close_impl() + self._is_closed = True + + def unget(self, data: Union[str, bytes]): + """Unshift data to buffer + + Args: + data: Data to revert + + Examples: + ``` + leak = tube.recvline().rstrip(b"> ") + tube.unget("> ") + # ... + tube.sendlineafter("> ", "1") + ``` + """ + assert isinstance(data, (str, bytes)), "`data` must be either str or bytes" + + self._buffer = str2bytes(data) + self._buffer + + def is_alive(self) -> bool: + """Check if connection is not closed + + Returns: + bool: False if connection is closed, otherwise True + + Examples: + ``` + while tube.is_alive(): + print(tube.recv()) + ``` + """ + if self._is_closed: + return False + return self._is_alive_impl() + + def shutdown(self, target: Literal['send', 'recv']): + """Kill one connection + + Args: + target (str): Connection to close (`send` or `recv`) + + Examples: + The following code shuts down input of remote. + ``` + tube.shutdown("send") + data = tube.recv() # OK + tube.send(b"data") # NG + ``` + + The following code shuts down output of remote. + ``` + tube.shutdown("recv") + tube.send(b"data") # OK + data = tube.recv() # NG + ``` + """ + if target in ['write', 'send', 'stdin']: + self._shutdown_send_impl() + self._is_send_closed = True + elif target in ['read', 'recv', 'stdout', 'stderr']: + self._shutdown_recv_impl() + self._is_recv_closed = True + else: + raise ValueError("`target` must either 'send' or 'recv'") def __enter__(self): return self - def __exit__(self, e_type, e_value, traceback): - self.close() + def __exit__(self, _e_type, _e_value, _traceback): + if not self._is_closed: + self.close() - @abstractmethod - def is_alive(self) -> bool: + def __str__(self) -> str: + return "" + + def __del__(self): + if hasattr(self, '_init_done') and not self._is_closed: + self.close() + + # + # Abstract methods + # + @abc.abstractmethod + def _settimeout_impl(self, timeout: Union[int, float]): + """Abstract method for `settimeout` + + Set timeout for receive and send. + + Args: + timeout: Timeout in second + """ pass - @abstractmethod - def close(self): + @abc.abstractmethod + def _recv_impl(self, size: int) -> bytes: + """Abstract method for `recv` + + Receives at most `size` bytes from tube. + This method must be a blocking method. + """ pass - @abstractmethod - def shutdown(self, target: Literal['send', 'recv']): + @abc.abstractmethod + def _send_impl(self, data: bytes) -> int: + """Abstract method for `send` + + Sends tube as much data as possible. + + Args: + data: Data to send + """ + pass + + @abc.abstractmethod + def _close_impl(self): + """Abstract method for `close` + + Close the connection. + This method is ensured to be called only once. + """ + pass + + @abc.abstractmethod + def _is_alive_impl(self) -> bool: + """Abstract method for `is_alive` + + This method must return True iff the connection is alive. + """ + pass + + @abc.abstractmethod + def _shutdown_recv_impl(self): + """Kill receiver connection + """ + pass + + @abc.abstractmethod + def _shutdown_send_impl(self): + """Kill sender connection + """ pass diff --git a/ptrlib/connection/winproc.py b/ptrlib/connection/winproc.py index d2f6561..390d480 100644 --- a/ptrlib/connection/winproc.py +++ b/ptrlib/connection/winproc.py @@ -1,16 +1,16 @@ -# coding: utf-8 from logging import getLogger -from typing import List, Mapping -from ptrlib.binary.encoding import * -from .tube import * -import ctypes +from typing import List, Mapping, Optional, Union import os -import time +import subprocess +from ptrlib.binary.encoding import bytes2str +from .tube import Tube _is_windows = os.name == 'nt' if _is_windows: + import pywintypes import win32api import win32con + import win32event import win32file import win32pipe import win32process @@ -18,233 +18,282 @@ logger = getLogger(__name__) - class WinPipe(object): - def __init__(self, inherit_handle: bool=True): + def __init__(self, + read: Optional[bool]=False, + write: Optional[bool]=False, + size: Optional[int]=65536): """Create a pipe for Windows - Create a new pipe - - Args: - inherit_handle (bool): Whether the child can inherit this handle - - Returns: - WinPipe: ``WinPipe`` instance. - """ - attr = win32security.SECURITY_ATTRIBUTES() - attr.bInheritHandle = inherit_handle - self.rp, self.wp = win32pipe.CreatePipe(attr, 0) - - @property - def handle0(self) -> int: - return self.get_handle('recv') - @property - def handle1(self) -> int: - return self.get_handle('send') - - def get_handle(self, name: Literal['recv', 'send']='recv') -> int: - """Get endpoint of this pipe + Create a new pipe with overlapped I/O. Args: - name (str): Handle to get (`recv` or `send`) + read: True if read mode + write: True if write mode + size: Default buffer size for this pipe + timeout: Default timeout in second """ - if name in ['read', 'recv', 'stdin']: - return self.rp - - elif name in ['write', 'send', 'stdout', 'stderr']: - return self.wp - + if read and write: + mode = win32pipe.PIPE_ACCESS_DUPLEX + self._access = win32con.GENERIC_READ | win32con.GENERIC_WRITE + elif write: + mode = win32pipe.PIPE_ACCESS_OUTBOUND + self._access = win32con.GENERIC_READ else: - logger.error("You must specify `send` or `recv` as target.") - - @property - def size(self) -> int: - """Get the number of bytes available to read on this pipe""" - # (lpBytesRead, lpTotalBytesAvail, lpBytesLeftThisMessage) - return win32pipe.PeekNamedPipe(self.handle0, 0)[1] - - def _recv(self, size: int=4096): - if size <= 0: - logger.error("`size` must be larger than 0") - return b'' - - buf = ctypes.create_string_buffer(size) - win32file.ReadFile(self.handle0, buf) - - return buf.raw - - def recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None): - """Receive raw data + mode = win32pipe.PIPE_ACCESS_INBOUND + self._access = win32con.GENERIC_WRITE - Receive raw data of maximum `size` bytes length through the pipe. + self._attr = win32security.SECURITY_ATTRIBUTES() + self._attr.bInheritHandle = True - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) - - Returns: - bytes: The received data - """ - start = time.time() - # Wait until data arrives - while self.size == 0: - # Check timeout - if timeout is not None and time.time() - start > timeout: - raise TimeoutError("Receive timeout") - time.sleep(0.01) - - return self._recv(min(self.size, size)) - - def send(self, data: bytes): - """Send raw data + self._name = f"\\\\.\\pipe\\ptrlib.{os.getpid()}.{os.urandom(8).hex()}" + self._handle = win32pipe.CreateNamedPipe( + self._name, mode | win32file.FILE_FLAG_OVERLAPPED, + win32pipe.PIPE_TYPE_BYTE | win32pipe.PIPE_READMODE_BYTE | win32pipe.PIPE_WAIT, + 1, size, size, 0, self._attr + ) + assert self._handle != win32file.INVALID_HANDLE_VALUE, \ + "Could not create a pipe" - Send raw data through the socket + @property + def name(self) -> str: + return self._name + + @property + def access(self) -> int: + return self._access + + @property + def attributes(self) -> any: + return self._attr - Args: - data (bytes) : Data to send - timeout (int): Timeout (in second) - """ - win32file.WriteFile(self.handle1, data) + @property + def handle(self) -> int: + return self._handle def close(self): - """Cleanly close this pipe""" - win32api.CloseHandle(self.rp) - win32api.CloseHandle(self.wp) + """Gracefully close this pipe + """ + win32api.CloseHandle(self._handle) def __del__(self): self.close() class WinProcess(Tube): - def __init__(self, args: Union[List[Union[str, bytes]], str], env: Optional[Mapping[str, str]]=None, cwd: Optional[str]=None, flags: int=0, timeout: Optional[Union[int, float]]=None): - """Create a process - - Create a new process and make a pipe. + # + # Constructor + # + def __init__(self, + args: Union[List[Union[str, bytes]], str], + env: Optional[Union[Mapping[bytes, Union[bytes, str]], Mapping[str, Union[bytes, str]]]]=None, + cwd: Optional[Union[bytes, str]]=None, + flags: int = 0, + raw: bool=False, + stdin : Optional[WinPipe]=None, + stdout: Optional[WinPipe]=None, + stderr: Optional[WinPipe]=None, + **kwargs): + """Create a Windows process + + Create a Windows process and make a pipe. Args: - args (list): The arguments to pass - env (list) : The environment variables + args : The arguments to pass + env : The environment variables + cwd : Working directory + flags : dwCreationFlags passed to CreateProcess + raw : Disable pty if this parameter is true + stdin : File descriptor of standard input + stdout : File descriptor of standard output + stderr : File descriptor of standard error Returns: - Process: ``Process`` instance. + WinProcess: ``WinProcess`` instance + + Examples: + ``` + p = Process("cmd.exe", cwd="C:\\") + p = Process(["cmd", "dir"], + stderr=subprocess.DEVNULL) + p = Process("more C:\\test.txt", env={"X": "123"}) + ``` """ - assert _is_windows - super().__init__() + assert _is_windows, "WinProcess cannot work on Unix" + assert isinstance(args, (str, bytes, list)), \ + "`args` must be either str, bytes, or list" + assert env is None or isinstance(env, dict), \ + "`env` must be a dictionary" + assert cwd is None or isinstance(cwd, (str, bytes)), \ + "`cwd` must be either str or bytes" + + self._current_timeout = 0 + super().__init__(**kwargs) if isinstance(args, list): for i, arg in enumerate(args): if isinstance(arg, bytes): args[i] = bytes2str(arg) - self.args = ' '.join(args) - self.filepath = args[0] - - # Check if arguments are safe for Windows - for arg in args: - if '"' not in arg: continue - if arg[0] == '"' and arg[-1] == '"': continue - logger.error("You have to escape the arguments by yourself.") - logger.error("Be noted what you are executing is") - logger.error("> " + self.args) + args = subprocess.list2cmdline(args) else: - self.args = args + args = bytes2str(args) - # Create pipe - self.stdin = WinPipe() - self.stdout = WinPipe() - self.default_timeout = timeout - self.timeout = timeout - self.proc = None + self._filepath = args - # Create process + # Prepare stdio + if stdin is None: + self._stdin = WinPipe(write=True) + proc_stdin = win32file.CreateFile( + self._stdin.name, self._stdin.access, + 0, self._stdin.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + if stdout is None: + self._stdout = WinPipe(read=True) + proc_stdout = win32file.CreateFile( + self._stdout.name, self._stdout.access, + 0, self._stdout.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + if stderr is None: + self._stderr = self._stdout + proc_stderr = proc_stdout + else: + proc_stderr = win32file.CreateFile( + self._stderr.name, self._stderr.access, + 0, self._stderr.attributes, + win32con.OPEN_EXISTING, win32file.FILE_ATTRIBUTE_NORMAL, None + ) + + # Open process info = win32process.STARTUPINFO() info.dwFlags = win32con.STARTF_USESTDHANDLES - info.hStdInput = self.stdin.handle0 - info.hStdOutput = self.stdout.handle1 - info.hStdError = self.stdout.handle1 - # (hProcess, hThread, dwProcessId, dwThreadId) - self.proc, _, self.pid, _ = win32process.CreateProcess( - None, self.args, # lpApplicationName, lpCommandLine - None, None, # lpProcessAttributes, lpThreadAttributes - True, flags, # bInheritHandles, dwCreationFlags - env, cwd, # lpEnvironment, lpCurrentDirectory - info # lpStartupInfo + info.hStdInput = proc_stdin + info.hStdOutput = proc_stdout + info.hStdError = proc_stderr + self._proc, _, self._pid, _ = win32process.CreateProcess( + None, args, None, None, True, flags, env, cwd, info ) - logger.info("Successfully created new process (PID={})".format(self.pid)) + win32file.CloseHandle(proc_stdin) + win32file.CloseHandle(proc_stdout) + if proc_stdout != proc_stderr: + win32file.CloseHandle(proc_stderr) - def _settimeout(self, timeout: Optional[Union[int, float]]): - """Set timeout value""" - if timeout is None: - self.timeout = self.default_timeout - elif timeout > 0: - self.timeout = timeout + # Wait until connection + win32pipe.ConnectNamedPipe(self._stdin.handle) + win32pipe.ConnectNamedPipe(self._stdout.handle) + win32pipe.ConnectNamedPipe(self._stderr.handle) - def _socket(self): - return self.proc + self._returncode = None - def _recv(self, size: int, timeout: Optional[Union[int, float]]=None) -> bytes: - """Receive raw data + logger.info(f"Successfully created new process {str(self)}") - Receive raw data of maximum `size` bytes length through the pipe. + # + # Property + # + @property + def returncode(self) -> Optional[int]: + return self._returncode - Args: - size (int): The data size to receive - timeout (int): Timeout (in second) + @property + def pid(self) -> int: + return self._pid - Returns: - bytes: The received data + # + # Implementation of Tube + # + def _settimeout_impl(self, timeout: Union[int, float]): + """Set timeout + + Args: + timeout: Timeout in second (Maximum precision is millisecond) """ - self._settimeout(timeout) - if size <= 0: - logger.error("`size` must be larger than 0") - return b'' + self._current_timeout = timeout - buf = self.stdout.recv(size, self.timeout) - return buf + def _recv_impl(self, size: int) -> bytes: + """Receive raw data - def is_alive(self) -> bool: - """Check if process is alive + Args: + size: Size to receive Returns: - bool: True if process is alive, otherwise False + bytes: Received data """ - if self.proc is None: - return False - else: - status = win32process.GetExitCodeProcess(self.proc) - return status == win32con.STILL_ACTIVE + if self._current_timeout == 0: + # Without timeout + try: + _, data = win32file.ReadFile(self._stdout.handle, size) + return data + except Exception as err: + raise err from None - def close(self): - if self.proc: - win32api.TerminateProcess(self.proc, 0) - win32api.CloseHandle(self.proc) - self.proc = None - logger.info("Process killed (PID={0})".format(self.pid)) - - def _send(self, data: bytes): + else: + # With timeout + overlapped = pywintypes.OVERLAPPED() + overlapped.hEvent = win32event.CreateEvent(None, 0, 0, None) + try: + _, data = win32file.ReadFile(self._stdout.handle, size, overlapped) + state = win32event.WaitForSingleObject( + overlapped.hEvent, int(self._current_timeout * 1000) + ) + if state == win32event.WAIT_OBJECT_0: + result = win32file.GetOverlappedResult(self._stdout.handle, overlapped, True) + if isinstance(result, int): + # NOTE: GetOverlappedResult does not return data + # when overlapped ReadFile is successful. + # We need to use the result of this API because + # we cannot access the number of bytes read by ReadFile. + # See https://github.com/mhammond/pywin32/issues/430 + return data[:result] + else: + return result[1] + else: + raise TimeoutError("Timeout (_recv_impl)", b'') + finally: + win32file.CloseHandle(overlapped.hEvent) + + def _send_impl(self, data: bytes) -> int: """Send raw data - Send raw data through the socket - Args: - data (bytes) : Data to send - """ - self.stdin.send(data) + data: Data to send - def shutdown(self, target: Literal['send', 'recv']): - """Close a connection - - Args: - target (str): Pipe to close (`recv` or `send`) + Returns: + int: The number of bytes written """ - if target in ['write', 'send', 'stdin']: - self.stdin.close() - - elif target in ['read', 'recv', 'stdout', 'stderr']: - self.stdout.close() + _, n = win32file.WriteFile(self._stdin.handle, data) + return n + + def _close_impl(self): + win32api.TerminateProcess(self._proc, 0) + win32api.CloseHandle(self._proc) + logger.info(f"Process killed {str(self)}") + + def _is_alive_impl(self) -> bool: + """Check if process is alive + Returns: + bool: True if process is alive, otherwise False + """ + status = win32process.GetExitCodeProcess(self._proc) + if status == win32con.STILL_ACTIVE: + return True else: - logger.error("You must specify `send` or `recv` as target.") + self._returncode = status + return False + + def _shutdown_recv_impl(self): + """Kill receiver connection + """ + self._stdout.close() + + def _shutdown_send_impl(self): + """Kill sender connection + """ + self._stdin.close() - def __del__(self): - self.close() + def __str__(self) -> str: + return f'{self._filepath} (PID={self._pid})' diff --git a/tests/binary/encoding/test_locale.py b/tests/binary/encoding/test_locale.py index 476c4dc..c82d9fb 100644 --- a/tests/binary/encoding/test_locale.py +++ b/tests/binary/encoding/test_locale.py @@ -1,5 +1,5 @@ import unittest -from ptrlib.binary.encoding.locale import * +from ptrlib.binary.encoding.char import * from logging import getLogger, FATAL diff --git a/tests/connection/test_proc.py b/tests/connection/test_proc.py index c6d5f3f..e4d017c 100644 --- a/tests/connection/test_proc.py +++ b/tests/connection/test_proc.py @@ -26,7 +26,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') # sendline / recvline p.sendline(b"Message : " + msg) @@ -70,7 +70,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p.close() self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:.+ \(PID=\d+\) has already exited$') + self.assertEqual(cm.output[0], fr'INFO:{module_name}:{str(p)} stopped with exit code 0') def test_timeout(self): module_name = inspect.getmodule(Process).__name__ @@ -78,7 +78,7 @@ def test_timeout(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], fr'INFO:{module_name}:Successfully created new process {str(p)}') data = os.urandom(16).hex() # recv diff --git a/tests/connection/test_sock.py b/tests/connection/test_sock.py index 73cd431..e5b35ea 100644 --- a/tests/connection/test_sock.py +++ b/tests/connection/test_sock.py @@ -29,18 +29,20 @@ def test_timeout(self): sock = Socket("www.example.com", 80) sock.sendline(b'GET / HTTP/1.1\r') sock.send(b'Host: www.example.com\r\n\r\n') - try: - sock.recvuntil("*** never expected ***", timeout=1) - result = False - except TimeoutError as err: - self.assertEqual(b"200 OK" in err.args[1], True) - result = True - except: - result = False - finally: - sock.close() - self.assertEqual(result, True) + with self.assertRaises(TimeoutError) as cm: + sock.recvuntil("*** never expected ***", timeout=2) + self.assertEqual(b"200 OK" in cm.exception.args[1], True) + + def test_reset(self): + sock = Socket("www.example.com", 80) + sock.sendline(b'GET / HTTP/1.1\r') + sock.send(b'Host: www.example.com\r\n') + sock.send(b'Connection: close\r\n\r\n') + + with self.assertRaises(ConnectionResetError) as cm: + sock.recvuntil("*** never expected ***", timeout=2) + self.assertEqual(b"200 OK" in cm.exception.args[1], True) def test_tls(self): host = "www.example.com" diff --git a/tests/connection/test_windows_proc.py b/tests/connection/test_windows_proc.py index 5282a3f..66ce000 100644 --- a/tests/connection/test_windows_proc.py +++ b/tests/connection/test_windows_proc.py @@ -27,7 +27,7 @@ def test_basic(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.pe.exe") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') pid = p.pid # send / recv @@ -58,7 +58,7 @@ def test_timeout(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_echo.pe.exe") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') with self.assertRaises(TimeoutError): p.recvuntil("*** never expected ***", timeout=1) diff --git a/tests/pwn/test_fsb.py b/tests/pwn/test_fsb.py index a4fcf30..d38d505 100644 --- a/tests/pwn/test_fsb.py +++ b/tests/pwn/test_fsb.py @@ -23,7 +23,7 @@ def test_fsb32(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x86") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb( @@ -42,7 +42,7 @@ def test_fsb32(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x86") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb( @@ -65,7 +65,7 @@ def test_fsb64(self): with self.assertLogs(module_name) as cm: p = Process("./tests/test.bin/test_fsb.x64") self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$') + self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}') p.recvuntil(": ") target = int(p.recvline(), 16) payload = fsb(